Rebasing main's latest commits onto ravi/filter_support_rebased (#225)

- add code for two variants of filtered index, readme and CI tests

- add utils for synthetic label generation and CI tests.


* Add co-authors

Co-authored-by: ravishankar <rakri@microsoft.com>
Co-authored-by: Varun Sivashankar <t-varunsi@microsoft.com>

---------

Co-authored-by: ravishankar <rakri@microsoft.com>
Co-authored-by: David Kaczynski <dkaczynski@microsoft.com>
Co-authored-by: Siddharth Gollapudi <t-gollapudis@microsoft.com>
Co-authored-by: Neelam Mahapatro <nmahapatro@microsoft.com>
Co-authored-by: Harsha Vardhan Simhadri <harshasi@microsoft.com>
Co-authored-by: Harsha Vardhan Simhadri <harsha-simhadri@users.noreply.github.com>
Co-authored-by: REDMOND\patelyash <patelyash@microsoft.com>
Co-authored-by: Varun Sivashankar <t-varunsi@microsoft.com>
This commit is contained in:
David Kaczynski 2023-03-15 16:49:48 -04:00 коммит произвёл GitHub
Родитель 5ba6a5d2c2
Коммит 5ec769aa85
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
33 изменённых файлов: 4154 добавлений и 898 удалений

42
.github/workflows/pr-test.yml поставляемый
Просмотреть файл

@ -6,6 +6,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-2019, windows-latest]
@ -71,6 +72,9 @@ jobs:
run: |
${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0
${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: Searching with fast_l2 distance function
if: runner.os != 'Windows'
run: |
${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
- name: build and search in-memory index with MIPS metric
run: |
@ -117,7 +121,6 @@ jobs:
- name: Generate 10K random int8 index vectors, 1K query vectors, in 10 dims and compute GT
run: |
run: |
${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
@ -154,10 +157,11 @@ jobs:
${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search an incremental index
run: |
${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200;
${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_1K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags
${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
- name: test a streaming index
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags
@ -165,6 +169,7 @@ jobs:
- name: Generate 10K random uint8 index vectors, 1K query vectors, in 10 dims and compute GT
if: success() || failure()
run: |
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
@ -172,44 +177,77 @@ jobs:
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn mips --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
- name: build and search in-memory index with L2 metrics
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search in-memory index with cosine metric
  if: success() || failure()
  # Build a cosine index, then search that same index. The search previously
  # pointed at the L2 index prefix, so the cosine index was never exercised.
  run: |
    ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn cosine --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_cosine_rand_uint8_10D_10K_norm50.0
    ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_cosine_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
- name: build and search disk index (one shot graph build, L2, no diskPQ)
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons)
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (sharded graph build, L2, no diskPQ)
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search disk index (one shot graph build, L2, diskPQ)
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
- name: build and search an incremental index
  if: success() || failure()
  # Inserts up to 7500 points while concurrently deleting the first 2500,
  # recomputes ground truth for the surviving points, then checks recall.
  # No trailing ';' after the first command, matching the int8 variant of
  # this step.
  run: |
    ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
    ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_10K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags
    ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
- name: test a streaming index
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
- name: Generate 10K random uint8 index vectors, 1K query vectors, 10K Label Points (50 unique labels), in 10 dims and compute GT
if: success() || failure()
run: |
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
${{ env.diskann_built_utils }}/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_10_10K.txt
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
- name: build and search in-memory index with labels using L2 metrics
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
- name: Build and search stitched vamana
if: success() || failure()
run: |
${{ env.diskann_built_tests }}/build_stitched_index --num_threads 48 --data_type uint8 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix ./stit_32_100_64_new --universal_label 0
${{ env.diskann_built_tests }}/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix ./stit_32_100_64_new --query_file ./rand_uint8_10D_1K_norm50.0.bin --result_path ./rand_stit_96_10_90_new --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 10 10 10 10 10 30 50 70 90 110 130 150 170 190 210 230 250 270 290 310 330 350 370 390 410
- uses: actions/setup-python@v3
- name: Install cibuildwheel
run: python -m pip install cibuildwheel==2.11.3

28
CMakeSettings.json Normal file
Просмотреть файл

@ -0,0 +1,28 @@
{
"configurations": [
{
"name": "x64-Release",
"generator": "Ninja",
"configurationType": "Release",
"inheritEnvironments": [ "msvc_x64" ],
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": ""
},
{
"name": "WSL-GCC-Release",
"generator": "Ninja",
"configurationType": "RelWithDebInfo",
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeExecutable": "cmake",
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": "",
"inheritEnvironments": [ "linux_x64" ],
"wslPath": "${defaultWSLPath}"
}
]
}

Просмотреть файл

@ -87,4 +87,17 @@ Please see the following pages on using the compiled code:
- [Commandline interface for building and searching SSD-based indices](workflows/SSD_index.md)
- [Commandline interface for building and searching in-memory indices](workflows/in_memory_index.md)
- [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
- [Commandline interface for building and searching in-memory indices with label data and filters](workflows/filtered_in_memory.md)
- To be added: Python interfaces and docker files
Please cite this software in your work as:
```
@misc{diskann-github,
author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddharth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan},
title = {{DiskANN: Scalable, efficient and Feature-rich ANNS}},
url = {https://github.com/Microsoft/DiskANN},
version = {0.5},
year = {2023}
}
```

Просмотреть файл

@ -40,7 +40,7 @@ namespace diskann {
const uint32_t WARMUP_L = 20;
const uint32_t NUM_KMEANS_REPS = 12;
template<typename T>
template<typename T, typename LabelT>
class PQFlashIndex;
DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str);
@ -68,38 +68,47 @@ namespace diskann {
uint64_t warmup_aligned_dim);
#endif
DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix,
const std::string &vamana_suffix,
const std::string &idmaps_prefix,
const std::string &idmaps_suffix,
const _u64 nshards, unsigned max_degree,
const std::string &output_vamana,
const std::string &medoids_file);
DISKANN_DLLEXPORT int merge_shards(
const std::string &vamana_prefix, const std::string &vamana_suffix,
const std::string &idmaps_prefix, const std::string &idmaps_suffix,
const _u64 nshards, unsigned max_degree, const std::string &output_vamana,
const std::string &medoids_file, bool use_filters = false,
const std::string &labels_to_medoids_file = std::string(""));
DISKANN_DLLEXPORT void extract_shard_labels(
const std::string &in_label_file, const std::string &shard_ids_bin,
const std::string &shard_label_file);
template<typename T>
DISKANN_DLLEXPORT std::string preprocess_base_file(
const std::string &infile, const std::string &indexPrefix,
diskann::Metric &distMetric);
template<typename T>
template<typename T, typename LabelT = uint32_t>
DISKANN_DLLEXPORT int build_merged_vamana_index(
std::string base_file, diskann::Metric _compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_file,
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters = false, const std::string &label_file = std::string(""),
const std::string &labels_to_medoids_file = std::string(""),
const std::string &universal_label = "", const _u32 Lf = 0);
template<typename T>
template<typename T, typename LabelT>
DISKANN_DLLEXPORT uint32_t optimize_beamwidth(
std::unique_ptr<diskann::PQFlashIndex<T>> &_pFlashIndex, T *tuning_sample,
_u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
uint32_t nthreads, uint32_t start_bw = 2);
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &_pFlashIndex,
T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim,
uint32_t L, uint32_t nthreads, uint32_t start_bw = 2);
template<typename T>
DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath,
const char *indexFilePath,
const char *indexBuildParameters,
diskann::Metric _compareMetric,
bool use_opq = false);
template<typename T, typename LabelT = uint32_t>
DISKANN_DLLEXPORT int build_disk_index(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric _compareMetric,
bool use_opq = false, bool use_filters = false,
const std::string &label_file =
std::string(""), // default is empty string for no label_file
const std::string &universal_label = "", const _u32 filter_threshold = 0,
const _u32 Lf = 0); // default is empty string for no universal label
template<typename T>
DISKANN_DLLEXPORT void create_disk_layout(

Просмотреть файл

@ -24,6 +24,7 @@
#define DEFAULT_MAXC 750
namespace diskann {
inline double estimate_ram_usage(_u64 size, _u32 dim, _u32 datasize,
_u32 degree) {
double size_of_data = ((double) size) * ROUND_UP(dim, 8) * datasize;
@ -60,7 +61,7 @@ namespace diskann {
}
};
template<typename T, typename TagT = uint32_t>
template<typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
class Index {
/**************************************************************************
*
@ -129,6 +130,17 @@ namespace diskann {
Parameters &parameters,
const std::vector<TagT> &tags);
// Filtered Support
DISKANN_DLLEXPORT void build_filtered_index(
const char *filename, const std::string &label_file,
const size_t num_points_to_load, Parameters &parameters,
const std::vector<TagT> &tags = std::vector<TagT>());
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
// Get converted integer label from string to int map (_label_map)
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);
// Set starting point of an index before inserting any points incrementally
DISKANN_DLLEXPORT void set_start_point(T *data);
// Set starting point to a random point on a sphere of certain radius
@ -155,6 +167,12 @@ namespace diskann {
float *distances,
std::vector<T *> &res_vectors);
// Filter support search
template<typename IndexType>
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(
const T *query, const LabelT &filter_label, const size_t K,
const unsigned L, IndexType *indices, float *distances);
// Will fail if tag already in the index or if tag=0.
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
@ -177,6 +195,8 @@ namespace diskann {
DISKANN_DLLEXPORT consolidation_report
consolidate_deletes(const Parameters &parameters);
DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters &parameters);
DISKANN_DLLEXPORT bool is_index_saved();
// repositions frozen points to the end of _data - if they have been moved
@ -208,8 +228,8 @@ namespace diskann {
protected:
// No copy/assign.
Index(const Index<T, TagT> &) = delete;
Index<T, TagT> &operator=(const Index<T, TagT> &) = delete;
Index(const Index<T, TagT, LabelT> &) = delete;
Index<T, TagT, LabelT> &operator=(const Index<T, TagT, LabelT> &) = delete;
// Use after _data and _nd have been populated
// Acquire exclusive _update_lock before calling
@ -223,14 +243,23 @@ namespace diskann {
// determines navigating node of the graph by calculating medoid of datafopt
unsigned calculate_entry_point();
void parse_label_file(const std::string &label_file,
size_t &num_pts_labels);
std::unordered_map<std::string, LabelT> load_label_map(
const std::string &map_file);
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(
const T *node_coords, const unsigned Lindex,
const std::vector<unsigned> &init_ids, InMemQueryScratch<T> *scratch,
bool use_filter, const std::vector<LabelT> &filters,
bool ret_frozen = true, bool search_invocation = false);
void search_for_point_and_prune(int location, _u32 Lindex,
std::vector<unsigned> &pruned_list,
InMemQueryScratch<T> *scratch);
InMemQueryScratch<T> *scratch,
bool use_filter = false,
_u32 filteredLindex = 0);
void prune_neighbors(const unsigned location, std::vector<Neighbor> &pool,
std::vector<unsigned> &pruned_list,
@ -342,6 +371,19 @@ namespace diskann {
bool _enable_tags = false;
bool _normalize_vecs = false; // Using normalized L2 for cosine.
// Filter Support
bool _filtered_index = false;
std::vector<std::vector<LabelT>> _pts_to_labels;
tsl::robin_set<LabelT> _labels;
std::string _labels_file;
std::unordered_map<LabelT, _u32> _label_to_medoid_id;
std::unordered_map<_u32, _u32> _medoid_counts;
bool _use_universal_label = false;
LabelT _universal_label = 0;
uint32_t _filterIndexingQueueSize;
std::unordered_map<std::string, LabelT> _label_map;
// Indexing parameters
uint32_t _indexingQueueSize;
uint32_t _indexingRange;

Просмотреть файл

@ -20,7 +20,7 @@
namespace diskann {
template<typename T>
template<typename T, typename LabelT = uint32_t>
class PQFlashIndex {
public:
DISKANN_DLLEXPORT PQFlashIndex(
@ -70,11 +70,26 @@ namespace diskann {
float *res_dists, const _u64 beam_width,
const bool use_reorder_data = false, QueryStats *stats = nullptr);
DISKANN_DLLEXPORT void cached_beam_search(
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
float *res_dists, const _u64 beam_width, const bool use_filter,
const LabelT &filter_label, const bool use_reorder_data = false,
QueryStats *stats = nullptr);
DISKANN_DLLEXPORT void cached_beam_search(
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
float *res_dists, const _u64 beam_width, const _u32 io_limit,
const bool use_reorder_data = false, QueryStats *stats = nullptr);
DISKANN_DLLEXPORT void cached_beam_search(
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
float *res_dists, const _u64 beam_width, const bool use_filter,
const LabelT &filter_label, const _u32 io_limit,
const bool use_reorder_data = false, QueryStats *stats = nullptr);
DISKANN_DLLEXPORT LabelT
get_converted_label(const std::string &filter_label);
DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range,
const _u64 min_l_search,
const _u64 max_l_search,
@ -94,12 +109,26 @@ namespace diskann {
DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads,
_u64 visited_reserve = 4096);
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
private:
DISKANN_DLLEXPORT inline bool point_has_label(_u32 point_id, _u32 label_id);
std::unordered_map<std::string, LabelT> load_label_map(
const std::string &map_file);
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file,
size_t &num_pts_labels);
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file,
_u32 &num_pts,
_u32 &num_total_labels);
DISKANN_DLLEXPORT inline int32_t get_filter_number(
const LabelT &filter_label);
// index info
// nhood of node `i` is in sector: [i / nnodes_per_sector]
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
// nnbrs of node `i`: *(unsigned*) (buf)
// nbrs of node `i`: ((unsigned*)buf) + 1
_u64 max_node_len = 0, nnodes_per_sector = 0, max_degree = 0;
// Data used for searching with re-order vectors
@ -171,6 +200,20 @@ namespace diskann {
bool reorder_data_exists = false;
_u64 reoreder_data_offset = 0;
// filter support
_u32 *_pts_to_label_offsets = nullptr;
_u32 *_pts_to_labels = nullptr;
tsl::robin_set<LabelT> _labels;
std::unordered_map<LabelT, _u32> _filter_to_medoid_id;
// Explicit defaults so neither member can be read while indeterminate;
// diskann::Index gives its own _use_universal_label the same default.
bool _use_universal_label = false;
_u32 _universal_filter_num = 0;
std::vector<LabelT> _filter_list;
tsl::robin_set<_u32> _dummy_pts;
tsl::robin_set<_u32> _has_dummy_pts;
tsl::robin_map<_u32, _u32> _dummy_to_real_map;
tsl::robin_map<_u32, std::vector<_u32>> _real_to_dummy_map;
// Raw label string -> converted label value.
// NOTE(review): presumably filled by load_label_map() - confirm.
std::unordered_map<std::string, LabelT> _label_map;
#ifdef EXEC_ENV_OLS
// Set to a larger value than the actual header to accommodate
// any additions we make to the header. This is an outer limit

Просмотреть файл

@ -31,19 +31,24 @@ typedef int FileHandle;
#include "memory_mapped_files.h"
#endif
#include <unordered_map>
#include <sstream>
#include <iostream>
// taken from
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
// round up X to the nearest multiple of Y
#define ROUND_UP(X, Y) \
((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y))
((((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0)) * (Y))
#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0))
#define DIV_ROUND_UP(X, Y) \
(((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0))
// round down X to the nearest multiple of Y
#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y))
#define ROUND_DOWN(X, Y) (((uint64_t) (X) / (Y)) * (Y))
// alignment tests
#define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0)
#define IS_ALIGNED(X, Y) ((uint64_t) (X) % (uint64_t) (Y) == 0)
#define IS_512_ALIGNED(X) IS_ALIGNED(X, 512)
#define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096)
#define METADATA_SIZE \
@ -92,8 +97,9 @@ typedef uint16_t _u16;
typedef int16_t _s16;
typedef uint8_t _u8;
typedef int8_t _s8;
inline void open_file_to_write(std::ofstream& writer,
const std::string& filename) {
inline void open_file_to_write(std::ofstream& writer,
const std::string& filename) {
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
if (!file_exists(filename))
writer.open(filename, std::ios::binary | std::ios::out);
@ -144,6 +150,48 @@ inline int delete_file(const std::string& fileName) {
}
}
inline void convert_labels_string_to_int(const std::string& inFileName,
const std::string& outFileName,
const std::string& mapFileName,
const std::string& unv_label) {
std::unordered_map<std::string, _u32> string_int_map;
std::ofstream label_writer(outFileName);
std::ifstream label_reader(inFileName);
if (unv_label != "")
string_int_map[unv_label] = 0;
std::string line, token;
while (std::getline(label_reader, line)) {
std::istringstream new_iss(line);
std::vector<_u32> lbls;
while (getline(new_iss, token, ',')) {
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
if (string_int_map.find(token) == string_int_map.end()) {
_u32 nextId = (_u32) string_int_map.size() + 1;
string_int_map[token] = nextId;
}
lbls.push_back(string_int_map[token]);
}
if (lbls.size() <= 0) {
std::cout << "No label found";
exit(-1);
}
for (size_t j = 0; j < lbls.size(); j++) {
if (j != lbls.size() - 1)
label_writer << lbls[j] << ",";
else
label_writer << lbls[j] << std::endl;
}
}
label_writer.close();
std::ofstream map_writer(mapFileName);
for (auto mp : string_int_map) {
map_writer << mp.first << "\t" << mp.second << std::endl;
}
map_writer.close();
}
#ifdef EXEC_ENV_OLS
class AlignedFileReader;
#endif
@ -568,6 +616,19 @@ namespace diskann {
}
#endif
// Copies in_file to out_file byte for byte. out_file is opened with the
// default std::ofstream mode plus std::ios::binary, so existing content is
// replaced. Stream failures are not checked or reported.
inline void copy_file(std::string in_file, std::string out_file) {
  std::ifstream reader(in_file, std::ios::binary);
  std::ofstream writer(out_file, std::ios::binary);

  // Pump raw characters straight from one stream buffer into the other; a
  // default-constructed istreambuf_iterator marks end-of-stream.
  std::copy(std::istreambuf_iterator<char>(reader),
            std::istreambuf_iterator<char>(),
            std::ostreambuf_iterator<char>(writer));

  writer.close();
  reader.close();
}
DISKANN_DLLEXPORT double calculate_recall(
unsigned num_queries, unsigned* gold_std, float* gs_dist, unsigned dim_gs,
unsigned* our_results, unsigned dim_or, unsigned recall_at);
@ -947,7 +1008,7 @@ inline void normalize(T* arr, size_t dim) {
}
sum = sqrt(sum);
for (uint32_t i = 0; i < dim; i++) {
arr[i] = (T)(arr[i] / sum);
arr[i] = (T) (arr[i] / sum);
}
}

Просмотреть файл

@ -25,13 +25,12 @@ PYBIND11_MAKE_OPAQUE(std::vector<float>);
PYBIND11_MAKE_OPAQUE(std::vector<int8_t>);
PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
namespace py = pybind11;
using namespace diskann;
template<class T>
struct DiskANNIndex {
PQFlashIndex<T> * pq_flash_index;
PQFlashIndex<T> *pq_flash_index;
std::shared_ptr<AlignedFileReader> reader;
DiskANNIndex(diskann::Metric metric) {
@ -74,8 +73,7 @@ struct DiskANNIndex {
const size_t num_nodes_to_cache, int cache_mechanism) {
const std::string index_path =
index_path_prefix + std::string("_disk.index");
int load_success =
pq_flash_index->load(num_threads, index_path.c_str());
int load_success = pq_flash_index->load(num_threads, index_path.c_str());
if (load_success != 0) {
return load_success;
}
@ -197,8 +195,8 @@ struct DiskANNIndex {
const int num_threads) {
py::array_t<unsigned> offsets(num_queries + 1);
std::vector<std::vector<_u64> > u64_ids(num_queries);
std::vector<std::vector<float> > dists(num_queries);
std::vector<std::vector<_u64>> u64_ids(num_queries);
std::vector<std::vector<float>> dists(num_queries);
auto offsets_mutable = offsets.mutable_unchecked();
offsets_mutable(0) = 0;
@ -246,8 +244,9 @@ struct DiskANNIndex {
stats, num_queries,
[](const diskann::QueryStats &stats) { return stats.n_cmps; });
delete[] stats;
return std::make_pair(std::make_pair(offsets, std::make_pair(ids, res_dists)),
collective_stats);
return std::make_pair(
std::make_pair(offsets, std::make_pair(ids, res_dists)),
collective_stats);
}
};
@ -259,16 +258,15 @@ PYBIND11_MODULE(diskannpy, m) {
m.attr("__version__") = "dev";
#endif
py::bind_vector<std::vector<unsigned> >(m, "VectorUnsigned");
py::bind_vector<std::vector<float> >(m, "VectorFloat");
py::bind_vector<std::vector<int8_t> >(m, "VectorInt8");
py::bind_vector<std::vector<uint8_t> >(m, "VectorUInt8");
py::bind_vector<std::vector<unsigned>>(m, "VectorUnsigned");
py::bind_vector<std::vector<float>>(m, "VectorFloat");
py::bind_vector<std::vector<int8_t>>(m, "VectorInt8");
py::bind_vector<std::vector<uint8_t>>(m, "VectorUInt8");
py::enum_<Metric>(m, "Metric")
.value("L2", Metric::L2)
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
.export_values();
.value("L2", Metric::L2)
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
.export_values();
py::class_<Parameters>(m, "Parameters")
.def(py::init<>())
@ -294,9 +292,11 @@ PYBIND11_MODULE(diskannpy, m) {
py::class_<AlignedFileReader>(m, "AlignedFileReader");
#ifdef _WINDOWS
py::class_<WindowsAlignedFileReader>(m, "WindowsAlignedFileReader").def(py::init<>());
py::class_<WindowsAlignedFileReader>(m, "WindowsAlignedFileReader")
.def(py::init<>());
#else
py::class_<LinuxAlignedFileReader>(m, "LinuxAlignedFileReader").def(py::init<>());
py::class_<LinuxAlignedFileReader>(m, "LinuxAlignedFileReader")
.def(py::init<>());
#endif
m.def(
@ -327,7 +327,7 @@ PYBIND11_MODULE(diskannpy, m) {
[](const std::string &path, std::vector<unsigned> &ids,
std::vector<float> &distances) {
unsigned *id_ptr = nullptr;
float * dist_ptr = nullptr;
float *dist_ptr = nullptr;
size_t num, dims;
load_truthset(path, id_ptr, dist_ptr, num, dims);
// TODO: Remove redundant copies.
@ -349,7 +349,7 @@ PYBIND11_MODULE(diskannpy, m) {
const unsigned ground_truth_dims, std::vector<unsigned> &results,
const unsigned result_dims, const unsigned recall_at) {
unsigned *gti_ptr = ground_truth_ids.data();
float * gtd_ptr = ground_truth_dists.data();
float *gtd_ptr = ground_truth_dists.data();
unsigned *r_ptr = results.data();
double total_recall = 0;
@ -390,10 +390,10 @@ PYBIND11_MODULE(diskannpy, m) {
std::vector<float> &ground_truth_dists,
const unsigned ground_truth_dims,
py::array_t<unsigned, py::array::c_style | py::array::forcecast>
& results,
&results,
const unsigned result_dims, const unsigned recall_at) {
unsigned *gti_ptr = ground_truth_ids.data();
float * gtd_ptr = ground_truth_dists.data();
float *gtd_ptr = ground_truth_dists.data();
unsigned *r_ptr = results.mutable_data();
double total_recall = 0;
@ -434,9 +434,9 @@ PYBIND11_MODULE(diskannpy, m) {
size_t dims) { save_bin<_u32>(file_name, data.data(), npts, dims); },
py::arg("file_name"), py::arg("data"), py::arg("npts"), py::arg("dims"));
py::class_<DiskANNIndex<float> >(m, "DiskANNFloatIndex")
py::class_<DiskANNIndex<float>>(m, "DiskANNFloatIndex")
.def(py::init([](diskann::Metric metric) {
return std::unique_ptr<DiskANNIndex<float> >(
return std::unique_ptr<DiskANNIndex<float>>(
new DiskANNIndex<float>(metric));
}))
.def("cache_bfs_levels", &DiskANNIndex<float>::cache_bfs_levels,
@ -462,8 +462,8 @@ PYBIND11_MODULE(diskannpy, m) {
.def("batch_range_search_numpy_input",
&DiskANNIndex<float>::batch_range_search_numpy_input,
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
py::arg("num_threads"))
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<float> &self, const char *data_file_path,
@ -485,13 +485,13 @@ PYBIND11_MODULE(diskannpy, m) {
py::arg("indexing_ram_limit"), py::arg("num_threads"),
py::arg("pq_disk_bytes") = 0);
py::class_<DiskANNIndex<int8_t> >(m, "DiskANNInt8Index")
py::class_<DiskANNIndex<int8_t>>(m, "DiskANNInt8Index")
.def(py::init([](diskann::Metric metric) {
return std::unique_ptr<DiskANNIndex<int8_t> >(
return std::unique_ptr<DiskANNIndex<int8_t>>(
new DiskANNIndex<int8_t>(metric));
}))
.def("cache_bfs_levels", &DiskANNIndex<int8_t>::cache_bfs_levels,
py::arg("num_nodes_to_cache"))
py::arg("num_nodes_to_cache"))
.def("load_index", &DiskANNIndex<int8_t>::load_index,
py::arg("index_path_prefix"), py::arg("num_threads"),
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
@ -513,8 +513,8 @@ PYBIND11_MODULE(diskannpy, m) {
.def("batch_range_search_numpy_input",
&DiskANNIndex<int8_t>::batch_range_search_numpy_input,
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
py::arg("num_threads"))
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<int8_t> &self, const char *data_file_path,
@ -536,10 +536,9 @@ PYBIND11_MODULE(diskannpy, m) {
py::arg("indexing_ram_limit"), py::arg("num_threads"),
py::arg("pq_disk_bytes") = 0);
py::class_<DiskANNIndex<uint8_t> >(m, "DiskANNUInt8Index")
py::class_<DiskANNIndex<uint8_t>>(m, "DiskANNUInt8Index")
.def(py::init([](diskann::Metric metric) {
return std::unique_ptr<DiskANNIndex<uint8_t> >(
return std::unique_ptr<DiskANNIndex<uint8_t>>(
new DiskANNIndex<uint8_t>(metric));
}))
.def("cache_bfs_levels", &DiskANNIndex<uint8_t>::cache_bfs_levels,
@ -565,8 +564,8 @@ PYBIND11_MODULE(diskannpy, m) {
.def("batch_range_search_numpy_input",
&DiskANNIndex<uint8_t>::batch_range_search_numpy_input,
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
py::arg("num_threads"))
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
py::arg("beam_width"), py::arg("num_threads"))
.def(
"build",
[](DiskANNIndex<uint8_t> &self, const char *data_file_path,

Просмотреть файл

@ -242,7 +242,8 @@ namespace diskann {
const std::string &idmaps_prefix,
const std::string &idmaps_suffix, const _u64 nshards,
unsigned max_degree, const std::string &output_vamana,
const std::string &medoids_file) {
const std::string &medoids_file, bool use_filters,
const std::string &labels_to_medoids_file) {
// Read ID maps
std::vector<std::string> vamana_names(nshards);
std::vector<std::vector<unsigned>> idmaps(nshards);
@ -283,6 +284,57 @@ namespace diskann {
});
diskann::cout << "Finished computing node -> shards map" << std::endl;
// Merge the "<shard index>_labels_to_medoids.txt" file of every shard into
// one combined file (labels_to_medoids_file). Each input line is read as
// "label,medoid": the first token is the label and the last token is a
// shard-local medoid id, which is translated to a global id through the
// shard's id map before being recorded.
if (use_filters) {
  std::unordered_map<unsigned, std::vector<_u32>> global_label_to_medoids;

  for (_u64 i = 0; i < nshards; i++) {
    std::string   map_file = vamana_names[i] + "_labels_to_medoids.txt";
    std::ifstream mapping_reader(map_file);

    std::string line, token;
    while (std::getline(mapping_reader, line)) {
      std::istringstream iss(line);
      _u32               cnt = 0;
      _u32               medoid = 0;
      _u32               label = 0;
      while (std::getline(iss, token, ',')) {
        token.erase(std::remove(token.begin(), token.end(), '\n'),
                    token.end());
        token.erase(std::remove(token.begin(), token.end(), '\r'),
                    token.end());
        // An empty token (e.g. a blank CRLF line) would make stoul throw.
        if (token.empty())
          continue;
        unsigned token_as_num = (unsigned) std::stoul(token);
        if (cnt == 0)
          label = token_as_num;
        else
          medoid = token_as_num;
        cnt++;
      }

      // A line without both a label and a medoid used to leave one or both
      // variables uninitialized; skip it instead of recording garbage.
      if (cnt < 2) {
        diskann::cout << "WARNING: skipping malformed line in " << map_file
                      << std::endl;
        continue;
      }
      // Guard the id-map lookup against a medoid id outside this shard.
      if (medoid >= idmaps[i].size()) {
        diskann::cout << "WARNING: medoid " << medoid << " in " << map_file
                      << " is out of range for shard #" << i << std::endl;
        continue;
      }
      global_label_to_medoids[label].push_back(idmaps[i][medoid]);
    }
    mapping_reader.close();
  }

  std::ofstream mapping_writer(labels_to_medoids_file);
  assert(mapping_writer.is_open());
  // Output format per line: "label, medoid_1, medoid_2, ..."
  for (const auto &entry : global_label_to_medoids) {
    mapping_writer << entry.first;
    for (const _u32 medoid_id : entry.second) {
      mapping_writer << ", " << medoid_id;
    }
    mapping_writer << std::endl;
  }
  mapping_writer.close();
}
// create cached vamana readers
std::vector<cached_ifstream> vamana_readers(nshards);
for (_u64 i = 0; i < nshards; i++) {
@ -384,10 +436,16 @@ namespace diskann {
}
// read from shard_id ifstream
vamana_readers[shard_id].read((char *) &shard_nnbrs, sizeof(unsigned));
std::vector<unsigned> shard_nhood(shard_nnbrs);
vamana_readers[shard_id].read((char *) shard_nhood.data(),
shard_nnbrs * sizeof(unsigned));
if (shard_nnbrs == 0) {
diskann::cout << "WARNING: shard #" << shard_id << ", node_id "
<< node_id << " has 0 nbrs" << std::endl;
}
std::vector<unsigned> shard_nhood(shard_nnbrs);
if (shard_nnbrs > 0)
vamana_readers[shard_id].read((char *) shard_nhood.data(),
shard_nnbrs * sizeof(unsigned));
// rename nodes
for (_u64 j = 0; j < shard_nnbrs; j++) {
if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) {
@ -402,8 +460,10 @@ namespace diskann {
nnbrs = (unsigned) (std::min)(final_nhood.size(), (uint64_t) max_degree);
// write into merged ofstream
merged_vamana_writer.write((char *) &nnbrs, sizeof(unsigned));
merged_vamana_writer.write((char *) final_nhood.data(),
nnbrs * sizeof(unsigned));
if (nnbrs > 0) {
merged_vamana_writer.write((char *) final_nhood.data(),
nnbrs * sizeof(unsigned));
}
merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned));
for (auto &p : final_nhood)
nhood_set[p] = 0;
@ -418,47 +478,204 @@ namespace diskann {
return 0;
}
// TODO: Make this a streaming implementation to avoid exceeding the memory
// budget
/* If the number of filters per point N exceeds the graph degree R,
then it is difficult to have edges to all labels from this point.
This function breaks up such dense points so that each point has at most
T labels. It divides one graph node into multiple nodes and appends
the new nodes at the end. The dummy map maps the id of each new node added
to the graph to the real graph id of the point it was split from. */
template<typename T>
int build_merged_vamana_index(std::string base_file,
diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate,
double ram_budget, std::string mem_index_path,
std::string medoids_file,
std::string centroids_file,
size_t build_pq_bytes, bool use_opq) {
void breakup_dense_points(const std::string data_file,
const std::string labels_file, _u32 density,
const std::string out_data_file,
const std::string out_labels_file,
const std::string out_metadata_file) {
std::string token, line;
std::ifstream labels_stream(labels_file);
T *data;
_u64 npts, ndims;
diskann::load_bin<T>(data_file, data, npts, ndims);
std::unordered_map<_u32, _u32> dummy_pt_ids;
_u32 next_dummy_id = (_u32) npts;
_u32 point_cnt = 0;
std::vector<std::vector<uint32_t>> labels_per_point;
labels_per_point.resize(npts);
_u32 dense_pts = 0;
if (labels_stream.is_open()) {
while (getline(labels_stream, line)) {
std::stringstream iss(line);
_u32 lbl_cnt = 0;
_u32 label_host = point_cnt;
while (getline(iss, token, ',')) {
if (lbl_cnt == density) {
if (label_host == point_cnt)
dense_pts++;
label_host = next_dummy_id;
labels_per_point.resize(next_dummy_id + 1);
dummy_pt_ids[next_dummy_id] = (_u32) point_cnt;
next_dummy_id++;
lbl_cnt = 0;
}
token.erase(std::remove(token.begin(), token.end(), '\n'),
token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'),
token.end());
unsigned token_as_num = std::stoul(token);
labels_per_point[label_host].push_back(token_as_num);
lbl_cnt++;
}
point_cnt++;
}
}
diskann::cout << "fraction of dense points with >= " << density
<< " labels = " << (float) dense_pts / (float) npts
<< std::endl;
if (labels_per_point.size() != 0) {
diskann::cout << labels_per_point.size() << " is the new number of points"
<< std::endl;
std::ofstream label_writer(out_labels_file);
assert(label_writer.is_open());
for (_u32 i = 0; i < labels_per_point.size(); i++) {
for (_u32 j = 0; j < (labels_per_point[i].size() - 1); j++) {
label_writer << labels_per_point[i][j] << ",";
}
if (labels_per_point[i].size() != 0)
label_writer << labels_per_point[i][labels_per_point[i].size() - 1];
label_writer << std::endl;
}
label_writer.close();
}
if (dummy_pt_ids.size() != 0) {
diskann::cout << dummy_pt_ids.size()
<< " is the number of dummy points created" << std::endl;
data = (T *) std::realloc((void *) data,
labels_per_point.size() * ndims * sizeof(T));
std::ofstream dummy_writer(out_metadata_file);
assert(dummy_writer.is_open());
for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) {
dummy_writer << i->first << "," << i->second << std::endl;
std::memcpy(data + i->first * ndims, data + i->second * ndims,
ndims * sizeof(T));
}
dummy_writer.close();
}
diskann::save_bin<T>(out_data_file, data, labels_per_point.size(), ndims);
}
// Writes to shard_label_file the label lines of the points that belong to
// one shard.
//
//   in_label_file    - label file for the full data set; the i-th line is
//                      taken to be the labels of point i
//   shard_ids_bin    - binary file holding the ids of the shard's points
//   shard_label_file - output; one label line per matched shard id
//
// The label file is scanned once from the top, so a shard id is only matched
// if it is larger than the previously matched one (ids in ascending order).
void extract_shard_labels(const std::string &in_label_file,
                          const std::string &shard_ids_bin,
                          const std::string &shard_label_file) {
  diskann::cout << "Extracting labels for shard" << std::endl;

  _u32 *ids = nullptr;
  _u64  num_ids, tmp_dim;
  diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim);

  _u32          counter = 0, shard_counter = 0;
  std::string   cur_line;
  std::ifstream label_reader(in_label_file);
  std::ofstream label_writer(shard_label_file);

  assert(label_reader.is_open());
  // The original asserted the reader twice; the writer is what was missing.
  assert(label_writer.is_open());
  if (label_reader && label_writer) {
    while (std::getline(label_reader, cur_line)) {
      if (shard_counter >= num_ids) {
        break;  // every shard id has been matched
      }
      if (counter == ids[shard_counter]) {
        label_writer << cur_line << "\n";
        shard_counter++;
      }
      counter++;
    }
  }
  if (shard_counter < num_ids) {
    // Fewer label lines were written than the shard has points.
    diskann::cout << "WARNING: extracted labels for only " << shard_counter
                  << " of " << num_ids << " shard points" << std::endl;
  }

  delete[] ids;  // delete[] on nullptr is a no-op
}
template<typename T, typename LabelT>
int build_merged_vamana_index(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_file,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf) {
size_t base_num, base_dim;
diskann::get_bin_metadata(base_file, base_num, base_dim);
double full_index_ram =
estimate_ram_usage(base_num, base_dim, sizeof(T), R);
// TODO: Make this honest when there is filter support
if (full_index_ram < ram_budget * 1024 * 1024 * 1024) {
diskann::cout << "Full index fits in RAM budget, should consume at most "
<< full_index_ram / (1024 * 1024 * 1024)
<< "GiBs, so building in one shot" << std::endl;
diskann::Parameters paras;
paras.Set<unsigned>("L", (unsigned) L);
paras.Set<unsigned>("Lf", (unsigned) Lf);
paras.Set<unsigned>("R", (unsigned) R);
paras.Set<unsigned>("C", 750);
paras.Set<float>("alpha", 1.2f);
paras.Set<unsigned>("num_rnds", 2);
paras.Set<bool>("saturate_graph", 1);
if (!use_filters)
paras.Set<bool>("saturate_graph", 1);
else
paras.Set<bool>("saturate_graph", 0);
using TagT = uint32_t;
paras.Set<std::string>("save_path", mem_index_path);
std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
compareMetric, base_dim, base_num, false, false, false,
build_pq_bytes > 0, build_pq_bytes, use_opq));
_pvamanaIndex->build(base_file.c_str(), base_num, paras);
std::unique_ptr<diskann::Index<T, TagT, LabelT>> _pvamanaIndex =
std::unique_ptr<diskann::Index<T, TagT, LabelT>>(
new diskann::Index<T, TagT, LabelT>(
compareMetric, base_dim, base_num, false, false, false,
build_pq_bytes > 0, build_pq_bytes, use_opq));
if (!use_filters)
_pvamanaIndex->build(base_file.c_str(), base_num, paras);
else {
if (universal_label != "") { // indicates no universal label
LabelT unv_label_as_num = 0;
_pvamanaIndex->set_universal_label(unv_label_as_num);
}
_pvamanaIndex->build_filtered_index(base_file.c_str(), label_file,
base_num, paras);
}
_pvamanaIndex->save(mem_index_path.c_str());
if (use_filters) {
// need to copy the labels_to_medoids file to the specified input file
std::remove(labels_to_medoids_file.c_str());
std::string mem_labels_to_medoid_file =
mem_index_path + "_labels_to_medoids.txt";
copy_file(mem_labels_to_medoid_file, labels_to_medoids_file);
std::remove(mem_labels_to_medoid_file.c_str());
}
std::remove(medoids_file.c_str());
std::remove(centroids_file.c_str());
return 0;
}
// where the universal label is to be saved in the final graph
std::string final_index_universal_label_file =
mem_index_path + "_universal_label.txt";
std::string merged_index_prefix = mem_index_path + "_tempFiles";
Timer timer;
int num_parts =
int num_parts =
partition_with_ram_budget<T>(base_file, sampling_rate, ram_budget,
2 * R / 3, merged_index_prefix, 2);
diskann::cout << timer.elapsed_seconds_for_step("partitioning data")
@ -475,6 +692,9 @@ namespace diskann {
std::string shard_ids_file = merged_index_prefix + "_subshard-" +
std::to_string(p) + "_ids_uint32.bin";
std::string shard_labels_file = merged_index_prefix + "_subshard-" +
std::to_string(p) + "_labels.txt";
retrieve_shard_data_from_ids<T>(base_file, shard_ids_file,
shard_base_file);
@ -483,6 +703,7 @@ namespace diskann {
diskann::Parameters paras;
paras.Set<unsigned>("L", L);
paras.Set<unsigned>("Lf", Lf);
paras.Set<unsigned>("R", (2 * (R / 3)));
paras.Set<unsigned>("C", 750);
paras.Set<float>("alpha", 1.2f);
@ -496,17 +717,43 @@ namespace diskann {
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
compareMetric, shard_base_dim, shard_base_pts, false, false,
false, build_pq_bytes > 0, build_pq_bytes, use_opq));
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras);
if (!use_filters) {
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras);
} else {
diskann::extract_shard_labels(label_file, shard_ids_file,
shard_labels_file);
if (universal_label != "") { // indicates no universal label
LabelT unv_label_as_num = 0;
_pvamanaIndex->set_universal_label(unv_label_as_num);
}
_pvamanaIndex->build_filtered_index(
shard_base_file.c_str(), shard_labels_file, shard_base_pts, paras);
}
_pvamanaIndex->save(shard_index_file.c_str());
// copy universal label file from first shard to the final destination
// index, since all shards anyway share the universal label
if (p == 0) {
std::string shard_universal_label_file =
shard_index_file + "_universal_label.txt";
if (universal_label != "") {
copy_file(shard_universal_label_file,
final_index_universal_label_file);
}
}
std::remove(shard_base_file.c_str());
}
diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") << std::endl;
diskann::cout << timer.elapsed_seconds_for_step(
"building indices on shards")
<< std::endl;
timer.reset();
diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index",
merged_index_prefix + "_subshard-", "_ids_uint32.bin",
num_parts, R, mem_index_path, medoids_file);
diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl;
num_parts, R, mem_index_path, medoids_file,
use_filters, labels_to_medoids_file);
diskann::cout << timer.elapsed_seconds_for_step("merging indices")
<< std::endl;
// delete tempFiles
for (int p = 0; p < num_parts; p++) {
@ -514,6 +761,8 @@ namespace diskann {
merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
std::string shard_id_file = merged_index_prefix + "_subshard-" +
std::to_string(p) + "_ids_uint32.bin";
std::string shard_labels_file = merged_index_prefix + "_subshard-" +
std::to_string(p) + "_labels.txt";
std::string shard_index_file =
merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
std::string shard_index_file_data = shard_index_file + ".data";
@ -522,6 +771,17 @@ namespace diskann {
std::remove(shard_id_file.c_str());
std::remove(shard_index_file.c_str());
std::remove(shard_index_file_data.c_str());
if (use_filters) {
std::string shard_index_label_file = shard_index_file + "_labels.txt";
std::string shard_index_univ_label_file =
shard_index_file + "_universal_label.txt";
std::string shard_index_label_map_file =
shard_index_file + "_labels_to_medoids.txt";
std::remove(shard_labels_file.c_str());
std::remove(shard_index_label_file.c_str());
std::remove(shard_index_label_map_file.c_str());
std::remove(shard_index_univ_label_file.c_str());
}
}
return 0;
}
@ -530,11 +790,11 @@ namespace diskann {
// optimizes the beamwidth to maximize QPS for a given L_search subject to
// 99.9 latency not blowing up
template<typename T>
template<typename T, typename LabelT>
uint32_t optimize_beamwidth(
std::unique_ptr<diskann::PQFlashIndex<T>> &pFlashIndex, T *tuning_sample,
_u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
uint32_t nthreads, uint32_t start_bw) {
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim,
uint32_t L, uint32_t nthreads, uint32_t start_bw) {
uint32_t cur_bw = start_bw;
double max_qps = 0;
uint32_t best_bw = start_bw;
@ -799,10 +1059,13 @@ namespace diskann {
<< std::endl;
}
template<typename T>
template<typename T, typename LabelT>
int build_disk_index(const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters,
diskann::Metric compareMetric, bool use_opq) {
diskann::Metric compareMetric, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &universal_label,
const _u32 filter_threshold, const _u32 Lf) {
std::stringstream parser;
parser << std::string(indexBuildParameters);
std::string cur_param;
@ -863,7 +1126,9 @@ namespace diskann {
std::string base_file(dataFilePath);
std::string data_file_to_use = base_file;
std::string labels_file_original = label_file;
std::string index_prefix_path(indexFilePath);
std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt";
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
std::string pq_compressed_vectors_path =
index_prefix_path + "_pq_compressed.bin";
@ -871,6 +1136,19 @@ namespace diskann {
std::string disk_index_path = index_prefix_path + "_disk.index";
std::string medoids_path = disk_index_path + "_medoids.bin";
std::string centroids_path = disk_index_path + "_centroids.bin";
std::string labels_to_medoids_path =
disk_index_path + "_labels_to_medoids.txt";
std::string mem_labels_file = mem_index_path + "_labels.txt";
std::string disk_labels_file = disk_index_path + "_labels.txt";
std::string mem_univ_label_file = mem_index_path + "_universal_label.txt";
std::string disk_univ_label_file = disk_index_path + "_universal_label.txt";
std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt";
std::string dummy_remap_file =
disk_index_path +
"_dummy_remap.txt"; // remap will be used if we break-up points of high
// label-density to create copies
std::string sample_base_prefix = index_prefix_path + "_sample";
// optional, used if disk index file must store pq data
std::string disk_pq_pivots_path =
@ -930,6 +1208,27 @@ namespace diskann {
auto s = std::chrono::high_resolution_clock::now();
// If there is filter support, we break-up points which have too many labels
// into replica dummy points which evenly distribute the filters. The rest
// of index build happens on the augmented base and labels
std::string augmented_data_file, augmented_labels_file;
if (use_filters) {
convert_labels_string_to_int(labels_file_original, labels_file_to_use,
disk_labels_int_map_file, universal_label);
augmented_data_file = index_prefix_path + "_augmented_data.bin";
augmented_labels_file = index_prefix_path + "_augmented_labels.txt";
if (filter_threshold != 0) {
dummy_remap_file = index_prefix_path + "_dummy_remap.txt";
breakup_dense_points<T>(
data_file_to_use, labels_file_to_use, filter_threshold,
augmented_data_file, augmented_labels_file,
dummy_remap_file); // RKNOTE: This has large memory footprint, need
// to make this streaming
data_file_to_use = augmented_data_file;
labels_file_to_use = augmented_labels_file;
}
}
size_t points_num, dim;
Timer timer;
@ -956,7 +1255,8 @@ namespace diskann {
generate_quantized_data<T>(data_file_to_use, pq_pivots_path,
pq_compressed_vectors_path, compareMetric, p_val,
num_pq_chunks, use_opq);
diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl;
diskann::cout << timer.elapsed_seconds_for_step("generating quantized data")
<< std::endl;
// Gopal. Splitting diskann_dll into separate DLLs for search and build.
// This code should only be available in the "build" DLL.
@ -966,10 +1266,11 @@ namespace diskann {
#endif
timer.reset();
diskann::build_merged_vamana_index<T>(
diskann::build_merged_vamana_index<T, LabelT>(
data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
build_pq_bytes, use_opq);
build_pq_bytes, use_opq, use_filters, labels_file_to_use,
labels_to_medoids_path, universal_label, Lf);
diskann::cout << timer.elapsed_seconds_for_step(
"building merged vamana index")
<< std::endl;
@ -997,6 +1298,17 @@ namespace diskann {
double sample_sampling_rate = num_sample_points / points_num;
gen_random_slice<T>(data_file_to_use.c_str(), sample_base_prefix,
sample_sampling_rate);
if (use_filters) {
copy_file(labels_file_to_use, disk_labels_file);
std::remove(mem_labels_file.c_str());
if (universal_label != "") {
copy_file(mem_univ_label_file, disk_univ_label_file);
std::remove(mem_univ_label_file.c_str());
}
std::remove(augmented_data_file.c_str());
std::remove(augmented_labels_file.c_str());
std::remove(labels_file_to_use.c_str());
}
std::remove(mem_index_path.c_str());
if (use_disk_pq)
@ -1041,48 +1353,123 @@ namespace diskann {
uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim);
#endif
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t>(
std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint32_t>(
std::unique_ptr<diskann::PQFlashIndex<int8_t, uint32_t>> &pFlashIndex,
int8_t *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t>(
std::unique_ptr<diskann::PQFlashIndex<uint8_t>> &pFlashIndex,
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint32_t>(
std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint32_t>> &pFlashIndex,
uint8_t *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float>(
std::unique_ptr<diskann::PQFlashIndex<float>> &pFlashIndex,
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint32_t>(
std::unique_ptr<diskann::PQFlashIndex<float, uint32_t>> &pFlashIndex,
float *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT int build_disk_index<int8_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq);
template DISKANN_DLLEXPORT int build_disk_index<uint8_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq);
template DISKANN_DLLEXPORT int build_disk_index<float>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq);
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint16_t>(
std::unique_ptr<diskann::PQFlashIndex<int8_t, uint16_t>> &pFlashIndex,
int8_t *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint16_t>(
std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint16_t>> &pFlashIndex,
uint8_t *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint16_t>(
std::unique_ptr<diskann::PQFlashIndex<float, uint16_t>> &pFlashIndex,
float *tuning_sample, _u64 tuning_sample_num,
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
uint32_t start_bw);
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t>(
template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint32_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint32_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
template DISKANN_DLLEXPORT int build_disk_index<float, uint32_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
// LabelT = uint16
template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint16_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint16_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
template DISKANN_DLLEXPORT int build_disk_index<float, uint16_t>(
const char *dataFilePath, const char *indexFilePath,
const char *indexBuildParameters, diskann::Metric compareMetric,
bool use_opq, bool use_filters, const std::string &label_file,
const std::string &universal_label, const _u32 filter_threshold,
const _u32 Lf);
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint32_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
template DISKANN_DLLEXPORT int build_merged_vamana_index<float>(
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint32_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t>(
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint32_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
// Explicit instantiations of build_merged_vamana_index with 16-bit labels
// (LabelT = uint16_t), one for each data type instantiated here:
// int8_t, float and uint8_t.
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint16_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint16_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint16_t>(
std::string base_file, diskann::Metric compareMetric, unsigned L,
unsigned R, double sampling_rate, double ram_budget,
std::string mem_index_path, std::string medoids_path,
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
bool use_filters, const std::string &label_file,
const std::string &labels_to_medoids_file,
const std::string &universal_label, const _u32 Lf);
}; // namespace diskann

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -29,7 +29,7 @@
// block size for reading/ processing large files and matrices in blocks
#define BLOCK_SIZE 5000000
//#define SAVE_INFLATED_PQ true
// #define SAVE_INFLATED_PQ true
template<typename T>
void gen_random_slice(const std::string base_file,
@ -591,9 +591,9 @@ int partition_with_ram_budget(const std::string data_file,
train_dim, k_base, cluster_sizes);
for (auto &p : cluster_sizes) {
p = (_u64) (p /
sampling_rate); // to account for the fact that p is the size
// of the shard over the testing sample.
// to account for the fact that p is the size of the shard over the
// testing sample.
p = (_u64) (p / sampling_rate);
double cur_shard_ram_estimate =
diskann::estimate_ram_usage(p, train_dim, sizeof(T), graph_degree);

Просмотреть файл

@ -41,9 +41,9 @@
namespace diskann {
template<typename T>
PQFlashIndex<T>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader,
diskann::Metric m)
template<typename T, typename LabelT>
PQFlashIndex<T, LabelT>::PQFlashIndex(
std::shared_ptr<AlignedFileReader> &fileReader, diskann::Metric m)
: reader(fileReader), metric(m) {
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) {
if (std::is_floating_point<T>::value) {
@ -63,8 +63,8 @@ namespace diskann {
this->dist_cmp_float.reset(diskann::get_distance_function<float>(metric));
}
template<typename T>
PQFlashIndex<T>::~PQFlashIndex() {
template<typename T, typename LabelT>
PQFlashIndex<T, LabelT>::~PQFlashIndex() {
#ifndef EXEC_ENV_OLS
if (data != nullptr) {
delete[] data;
@ -86,10 +86,18 @@ namespace diskann {
this->reader->deregister_all_threads();
reader->close();
}
if (_pts_to_label_offsets != nullptr) {
delete[] _pts_to_label_offsets;
}
if (_pts_to_labels != nullptr) {
delete[] _pts_to_labels;
}
}
template<typename T>
void PQFlashIndex<T>::setup_thread_data(_u64 nthreads, _u64 visited_reserve) {
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::setup_thread_data(_u64 nthreads,
_u64 visited_reserve) {
diskann::cout << "Setting up thread-specific contexts for nthreads: "
<< nthreads << std::endl;
// omp parallel for to generate unique thread IDs
@ -107,8 +115,9 @@ namespace diskann {
load_flag = true;
}
template<typename T>
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) {
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::load_cache_list(
std::vector<uint32_t> &node_list) {
diskann::cout << "Loading the cache list into memory.." << std::flush;
_u64 num_cached_nodes = node_list.size();
@ -180,14 +189,14 @@ namespace diskann {
}
#ifdef EXEC_ENV_OLS
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
#else
template<typename T>
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(
std::string sample_bin, _u64 l_search, _u64 beamwidth,
_u64 num_nodes_to_cache, uint32_t nthreads,
std::vector<uint32_t> &node_list) {
@ -245,10 +254,10 @@ namespace diskann {
diskann::aligned_free(samples);
}
template<typename T>
void PQFlashIndex<T>::cache_bfs_levels(_u64 num_nodes_to_cache,
std::vector<uint32_t> &node_list,
const bool shuffle) {
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::cache_bfs_levels(
_u64 num_nodes_to_cache, std::vector<uint32_t> &node_list,
const bool shuffle) {
std::random_device rng;
std::mt19937 urng(rng());
@ -379,8 +388,8 @@ namespace diskann {
diskann::cout << "done" << std::endl;
}
template<typename T>
void PQFlashIndex<T>::use_medoids_data_as_centroids() {
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::use_medoids_data_as_centroids() {
if (centroid_data != nullptr)
aligned_free(centroid_data);
alloc_aligned(((void **) &centroid_data),
@ -425,13 +434,170 @@ namespace diskann {
}
}
// Returns the position of `filter_label` inside _filter_list, or -1 when the
// label is not present. The list is scanned front to back, so the first
// matching position is the one reported.
template<typename T, typename LabelT>
inline int32_t PQFlashIndex<T, LabelT>::get_filter_number(
    const LabelT &filter_label) {
  for (_u32 pos = 0; pos < _filter_list.size(); ++pos) {
    if (_filter_list[pos] == filter_label)
      return pos;
  }
  return -1;
}
// Reads a label-map text file and returns a string -> numeric-label table.
// Each line is split on tabs: the first field is used as the key and the
// second field is parsed with std::stoul to produce the mapped value.
// If the file cannot be opened the read loop never runs and an empty map is
// returned.
template<typename T, typename LabelT>
std::unordered_map<std::string, LabelT>
PQFlashIndex<T, LabelT>::load_label_map(const std::string &labels_map_file) {
  std::unordered_map<std::string, LabelT> label_ids;
  std::ifstream                           map_stream(labels_map_file);
  std::string                             row, field;

  while (std::getline(map_stream, row)) {
    std::istringstream fields(row);

    // First tab-separated field: the label's string form.
    getline(fields, field, '\t');
    const std::string label_name = field;

    // Second field: the numeric id. The same buffer is reused on purpose so
    // that the text handed to std::stoul is whatever this second read leaves
    // in it.
    getline(fields, field, '\t');
    const LabelT numeric_id = std::stoul(field);

    label_ids[label_name] = numeric_id;
  }
  return label_ids;
}
// Translates a label given as a string into its numeric form using
// _label_map.
//
// Returns the mapped value when `filter_label` is a key of _label_map.
// Otherwise logs the failure to diskann::cerr and throws
// diskann::ANNException.
template<typename T, typename LabelT>
LabelT PQFlashIndex<T, LabelT>::get_converted_label(
    const std::string &filter_label) {
  // Single lookup: reuse the iterator instead of find() followed by
  // operator[].
  const auto entry = _label_map.find(filter_label);
  if (entry != _label_map.end()) {
    return entry->second;
  }

  std::stringstream stream;
  stream << "Unable to find label in the Label Map";
  diskann::cerr << stream.str() << std::endl;
  // The throw always leaves the function, so no exit() call is needed after
  // it.
  throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
                              __LINE__);
}
// Scans a label file once to size the label storage.
//
// On return `num_pts` holds the number of lines read from `map_file` and
// `num_total_labels` holds the number of comma-separated tokens summed over
// all lines. Both outputs are reset to zero first, so a file that cannot be
// opened yields 0 / 0. The totals are also printed to diskann::cout.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::get_label_file_metadata(
    std::string map_file, _u32 &num_pts, _u32 &num_total_labels) {
  num_pts = 0;
  num_total_labels = 0;

  std::ifstream label_stream(map_file);
  // One line per point; the point counter advances after each line.
  for (std::string row; std::getline(label_stream, row); ++num_pts) {
    std::istringstream fields(row);
    for (std::string field; getline(fields, field, ',');) {
      field.erase(std::remove(field.begin(), field.end(), '\n'), field.end());
      field.erase(std::remove(field.begin(), field.end(), '\r'), field.end());
      ++num_total_labels;
    }
  }

  diskann::cout << "Labels file metadata: num_points: " << num_pts
                << ", #total_labels: " << num_total_labels << std::endl;
  label_stream.close();
}
// Reports whether `label_id` appears among the labels stored for `point_id`.
//
// _pts_to_label_offsets[point_id] is the index in _pts_to_labels of the
// point's label count; the labels themselves occupy the slots immediately
// after that count.
template<typename T, typename LabelT>
inline bool PQFlashIndex<T, LabelT>::point_has_label(_u32 point_id,
                                                     _u32 label_id) {
  const _u32 count_slot = _pts_to_label_offsets[point_id];
  const _u32 label_count = _pts_to_labels[count_slot];
  const _u32 first_label = count_slot + 1;

  for (_u32 k = 0; k < label_count; ++k) {
    if (_pts_to_labels[first_label + k] == label_id)
      return true;
  }
  return false;
}
// Loads the per-point labels from `label_file` into the packed arrays
// _pts_to_label_offsets / _pts_to_labels and records every distinct label in
// _labels and _filter_list.
//
// Layout produced: for the point on line i, _pts_to_label_offsets[i] is the
// index in _pts_to_labels of that point's label count, and the following
// slots hold one filter number per label (the label's position in
// _filter_list, as returned by get_filter_number).
//
// Only the text before the first tab of each line is parsed; it is split on
// commas and each token is converted with std::stoul.
//
// label_file        - path of the text file to read.
// num_points_labels - set to the number of lines parsed.
//
// Throws diskann::ANNException when the file cannot be opened. Calls
// exit(-1) when a line yields no label or a label cannot be located in
// _filter_list.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file,
                                               size_t &num_points_labels) {
  std::ifstream infile(label_file);
  if (infile.fail()) {
    throw diskann::ANNException(
        std::string("Failed to open file ") + label_file, -1);
  }

  std::string line, token;
  _u32        line_cnt = 0;

  // First pass over the file: count points and labels so both arrays can be
  // allocated up front (one count slot per point plus one slot per label).
  _u32 num_pts_in_label_file;
  _u32 num_total_labels;
  get_label_file_metadata(label_file, num_pts_in_label_file,
                          num_total_labels);

  _pts_to_label_offsets = new _u32[num_pts_in_label_file];
  _pts_to_labels = new _u32[num_pts_in_label_file + num_total_labels];
  _u32 counter = 0;

  while (std::getline(infile, line)) {
    std::istringstream iss(line);
    _pts_to_label_offsets[line_cnt] = counter;
    // Reference to this point's count slot; it is incremented as labels are
    // appended behind it.
    _u32 &num_lbls_in_cur_pt = _pts_to_labels[counter];
    num_lbls_in_cur_pt = 0;
    counter++;

    // Keep only the part of the line before the first tab, then split that
    // part on commas.
    getline(iss, token, '\t');
    std::istringstream new_iss(token);
    while (getline(new_iss, token, ',')) {
      token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
      token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
      LabelT token_as_num = std::stoul(token);
      // A label seen for the first time gets the next position in
      // _filter_list.
      if (_labels.find(token_as_num) == _labels.end()) {
        _filter_list.emplace_back(token_as_num);
      }
      int32_t filter_num = get_filter_number(token_as_num);
      if (filter_num == -1) {
        diskann::cout << "Error!! " << std::endl;
        exit(-1);
      }
      _pts_to_labels[counter++] = filter_num;
      num_lbls_in_cur_pt++;
      _labels.insert(token_as_num);
    }

    if (num_lbls_in_cur_pt == 0) {
      diskann::cout << "No label found for point " << line_cnt << std::endl;
      exit(-1);
    }
    line_cnt++;
  }
  infile.close();
  num_points_labels = line_cnt;
}
// Marks `label` as the universal label. The label must already be present in
// _filter_list; when it is not, an error is printed and the process exits
// with status -1. On success the label's filter number is stored in
// _universal_filter_num and _use_universal_label is switched on.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label) {
  const int32_t universal_pos = get_filter_number(label);

  // Guard: an unknown label is fatal.
  if (universal_pos == -1) {
    diskann::cout << "Error, could not find universal label. Exitting."
                  << std::endl;
    exit(-1);
  }

  _use_universal_label = true;
  _universal_filter_num = (_u32) universal_pos;
}
#ifdef EXEC_ENV_OLS
template<typename T>
int PQFlashIndex<T>::load(MemoryMappedFiles &files, uint32_t num_threads,
const char *index_prefix) {
template<typename T, typename LabelT>
int PQFlashIndex<T, LabelT>::load(MemoryMappedFiles &files,
uint32_t num_threads,
const char *index_prefix) {
#else
template<typename T>
int PQFlashIndex<T>::load(uint32_t num_threads, const char *index_prefix) {
template<typename T, typename LabelT>
int PQFlashIndex<T, LabelT>::load(uint32_t num_threads,
const char *index_prefix) {
#endif
std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin";
std::string pq_compressed_vectors =
@ -449,14 +615,14 @@ namespace diskann {
}
#ifdef EXEC_ENV_OLS
template<typename T>
int PQFlashIndex<T>::load_from_separate_paths(
template<typename T, typename LabelT>
int PQFlashIndex<T, LabelT>::load_from_separate_paths(
diskann::MemoryMappedFiles &files, uint32_t num_threads,
const char *index_filepath, const char *pivots_filepath,
const char *compressed_filepath) {
#else
template<typename T>
int PQFlashIndex<T>::load_from_separate_paths(
template<typename T, typename LabelT>
int PQFlashIndex<T, LabelT>::load_from_separate_paths(
uint32_t num_threads, const char *index_filepath,
const char *pivots_filepath, const char *compressed_filepath) {
#endif
@ -467,6 +633,15 @@ namespace diskann {
std::string centroids_file =
std::string(disk_index_file) + "_centroids.bin";
std::string labels_file = std ::string(disk_index_file) + "_labels.txt";
std::string labels_to_medoids =
std ::string(disk_index_file) + "_labels_to_medoids.txt";
std::string dummy_map_file =
std ::string(disk_index_file) + "_dummy_map.txt";
std::string labels_map_file =
std ::string(disk_index_file) + "_labels_map.txt";
size_t num_pts_in_label_file = 0;
size_t pq_file_dim, pq_file_num_centroids;
#ifdef EXEC_ENV_OLS
get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim,
@ -503,6 +678,77 @@ namespace diskann {
this->num_points = npts_u64;
this->n_chunks = nchunks_u64;
if (file_exists(labels_file)) {
parse_label_file(labels_file, num_pts_in_label_file);
assert(num_pts_in_label_file == this->num_points);
_label_map = load_label_map(labels_map_file);
if (file_exists(labels_to_medoids)) {
std::ifstream medoid_stream(labels_to_medoids);
assert(medoid_stream.is_open());
std::string line, token;
_filter_to_medoid_id.clear();
try {
while (std::getline(medoid_stream, line)) {
std::istringstream iss(line);
_u32 cnt = 0;
_u32 medoid;
LabelT label;
while (std::getline(iss, token, ',')) {
if (cnt == 0)
label = std::stoul(token);
else
medoid = (_u32) stoul(token);
cnt++;
}
_filter_to_medoid_id[label] = medoid;
}
} catch (std::system_error &e) {
throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__,
__LINE__);
}
}
std::string univ_label_file =
std ::string(disk_index_file) + "_universal_label.txt";
if (file_exists(univ_label_file)) {
std::ifstream universal_label_reader(univ_label_file);
assert(universal_label_reader.is_open());
std::string univ_label;
universal_label_reader >> univ_label;
universal_label_reader.close();
LabelT label_as_num = std::stoul(univ_label);
set_universal_label(label_as_num);
}
if (file_exists(dummy_map_file)) {
std::ifstream dummy_map_stream(dummy_map_file);
assert(dummy_map_stream.is_open());
std::string line, token;
while (std::getline(dummy_map_stream, line)) {
std::istringstream iss(line);
_u32 cnt = 0;
_u32 dummy_id;
_u32 real_id;
while (std::getline(iss, token, ',')) {
if (cnt == 0)
dummy_id = (_u32) stoul(token);
else
real_id = (_u32) stoul(token);
cnt++;
}
_dummy_pts.insert(dummy_id);
_has_dummy_pts.insert(real_id);
_dummy_to_real_map[dummy_id] = real_id;
if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end())
_real_to_dummy_map[real_id] = std::vector<_u32>();
_real_to_dummy_map[real_id].emplace_back(dummy_id);
}
dummy_map_stream.close();
diskann::cout << "Loaded dummy map" << std::endl;
}
}
#ifdef EXEC_ENV_OLS
pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64);
@ -541,8 +787,8 @@ namespace diskann {
disk_bytes_per_point =
disk_pq_n_chunks *
sizeof(_u8); // revising disk_bytes_per_point since DISK PQ is used.
std::cout << "Disk index uses PQ data compressed down to "
<< disk_pq_n_chunks << " bytes per point." << std::endl;
diskann::cout << "Disk index uses PQ data compressed down to "
<< disk_pq_n_chunks << " bytes per point." << std::endl;
}
// read index metadata
@ -704,8 +950,8 @@ namespace diskann {
float *norm_val;
diskann::load_bin<float>(norm_file, norm_val, dumr, dumc);
this->max_base_norm = norm_val[0];
std::cout << "Setting re-scaling factor of base vectors to "
<< this->max_base_norm << std::endl;
diskann::cout << "Setting re-scaling factor of base vectors to "
<< this->max_base_norm << std::endl;
delete[] norm_val;
}
diskann::cout << "done.." << std::endl;
@ -742,23 +988,58 @@ namespace diskann {
}
#endif
template<typename T>
void PQFlashIndex<T>::cached_beam_search(const T *query1, const _u64 k_search,
const _u64 l_search, _u64 *indices,
float *distances,
const _u64 beam_width,
const bool use_reorder_data,
QueryStats *stats) {
// Convenience overload: unfiltered beam search with no explicit I/O limit.
// Forwards to the overload that takes an io_limit, passing the largest _u32
// value as that limit.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::cached_beam_search(
    const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
    float *distances, const _u64 beam_width, const bool use_reorder_data,
    QueryStats *stats) {
  const _u32 max_io_limit = std::numeric_limits<_u32>::max();
  this->cached_beam_search(query1, k_search, l_search, indices, distances,
                           beam_width, max_io_limit, use_reorder_data, stats);
}
template<typename T>
void PQFlashIndex<T>::cached_beam_search(
// Convenience overload: beam search with a filter flag and label but no
// explicit I/O limit. Forwards to the overload that takes both the filter
// arguments and an io_limit, passing the largest _u32 value as that limit.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::cached_beam_search(
    const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
    float *distances, const _u64 beam_width, const bool use_filter,
    const LabelT &filter_label, const bool use_reorder_data,
    QueryStats *stats) {
  const _u32 max_io_limit = std::numeric_limits<_u32>::max();
  this->cached_beam_search(query1, k_search, l_search, indices, distances,
                           beam_width, use_filter, filter_label, max_io_limit,
                           use_reorder_data, stats);
}
// Convenience overload: unfiltered beam search with a caller-supplied I/O
// limit. Forwards to the overload that takes filter arguments, with
// filtering switched off.
//
// io_limit is passed through unchanged, so the limit chosen by the caller is
// the one the forwarded search receives.
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::cached_beam_search(
    const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
    float *distances, const _u64 beam_width, const _u32 io_limit,
    const bool use_reorder_data, QueryStats *stats) {
  // Placeholder label, passed only because the callee takes one; use_filter
  // is false.
  LabelT dummy_filter = 0;
  cached_beam_search(query1, k_search, l_search, indices, distances,
                     beam_width, false, dummy_filter, io_limit,
                     use_reorder_data, stats);
}
template<typename T, typename LabelT>
void PQFlashIndex<T, LabelT>::cached_beam_search(
const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
float *distances, const _u64 beam_width, const bool use_filter,
const LabelT &filter_label, const _u32 io_limit,
const bool use_reorder_data, QueryStats *stats) {
int32_t filter_num = 0;
if (use_filter) {
filter_num = get_filter_number(filter_label);
if (filter_num < 0) {
if (!_use_universal_label) {
return;
} else {
filter_num = _universal_filter_num;
}
}
}
if (beam_width > MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS",
-1, __FUNCSIG__, __FILE__, __LINE__);
@ -838,14 +1119,22 @@ namespace diskann {
_u32 best_medoid = 0;
float best_dist = (std::numeric_limits<float>::max)();
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
float cur_expanded_dist = dist_cmp_float->compare(
query_float, centroid_data + aligned_dim * cur_m,
(unsigned) aligned_dim);
if (cur_expanded_dist < best_dist) {
best_medoid = medoids[cur_m];
best_dist = cur_expanded_dist;
if (!use_filter) {
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
float cur_expanded_dist = dist_cmp_float->compare(
query_float, centroid_data + aligned_dim * cur_m,
(unsigned) aligned_dim);
if (cur_expanded_dist < best_dist) {
best_medoid = medoids[cur_m];
best_dist = cur_expanded_dist;
}
}
} else if (_filter_to_medoid_id.find(filter_label) !=
_filter_to_medoid_id.end()) {
best_medoid = _filter_to_medoid_id[filter_label];
} else {
throw ANNException("Cannot find medoid for specified filter.", -1,
__FUNCSIG__, __FILE__, __LINE__);
}
compute_dists(&best_medoid, 1, dist_scratch);
@ -963,6 +1252,12 @@ namespace diskann {
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];
if (visited.insert(id).second) {
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
continue;
if (use_filter && !point_has_label(id, filter_num) &&
!point_has_label(id, _universal_filter_num))
continue;
cmps++;
float dist = dist_scratch[m];
Neighbor nn(id, dist);
@ -1024,6 +1319,12 @@ namespace diskann {
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];
if (visited.insert(id).second) {
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
continue;
if (use_filter && !point_has_label(id, filter_num) &&
!point_has_label(id, _universal_filter_num))
continue;
cmps++;
float dist = dist_scratch[m];
if (stats != nullptr) {
@ -1096,6 +1397,11 @@ namespace diskann {
// copy k_search values
for (_u64 i = 0; i < k_search; i++) {
indices[i] = full_retset[i].id;
if (_dummy_pts.find(indices[i]) != _dummy_pts.end()) {
indices[i] = _dummy_to_real_map[indices[i]];
}
if (distances != nullptr) {
distances[i] = full_retset[i].distance;
if (metric == diskann::Metric::INNER_PRODUCT) {
@ -1121,14 +1427,12 @@ namespace diskann {
// range search returns results of all neighbors within distance of range.
// indices and distances need to be pre-allocated of size l_search and the
// return value is the number of matching hits.
template<typename T>
_u32 PQFlashIndex<T>::range_search(const T *query1, const double range,
const _u64 min_l_search,
const _u64 max_l_search,
std::vector<_u64> &indices,
std::vector<float> &distances,
const _u64 min_beam_width,
QueryStats *stats) {
template<typename T, typename LabelT>
_u32 PQFlashIndex<T, LabelT>::range_search(
const T *query1, const double range, const _u64 min_l_search,
const _u64 max_l_search, std::vector<_u64> &indices,
std::vector<float> &distances, const _u64 min_beam_width,
QueryStats *stats) {
_u32 res_count = 0;
bool stop_flag = false;
@ -1162,23 +1466,23 @@ namespace diskann {
return res_count;
}
template<typename T>
_u64 PQFlashIndex<T>::get_data_dim() {
template<typename T, typename LabelT>
_u64 PQFlashIndex<T, LabelT>::get_data_dim() {
return data_dim;
}
template<typename T>
diskann::Metric PQFlashIndex<T>::get_metric() {
template<typename T, typename LabelT>
diskann::Metric PQFlashIndex<T, LabelT>::get_metric() {
return this->metric;
}
#ifdef EXEC_ENV_OLS
template<typename T>
char *PQFlashIndex<T>::getHeaderBytes() {
template<typename T, typename LabelT>
char *PQFlashIndex<T, LabelT>::getHeaderBytes() {
IOContext &ctx = reader->get_ctx();
AlignedRead readReq;
readReq.buf = new char[PQFlashIndex<T>::HEADER_SIZE];
readReq.len = PQFlashIndex<T>::HEADER_SIZE;
readReq.buf = new char[PQFlashIndex<T, LabelT>::HEADER_SIZE];
readReq.len = PQFlashIndex<T, LabelT>::HEADER_SIZE;
readReq.offset = 0;
std::vector<AlignedRead> readReqs;
@ -1194,5 +1498,8 @@ namespace diskann {
template class PQFlashIndex<_u8>;
template class PQFlashIndex<_s8>;
template class PQFlashIndex<float>;
template class PQFlashIndex<_u8, uint16_t>;
template class PQFlashIndex<_s8, uint16_t>;
template class PQFlashIndex<float, uint16_t>;
} // namespace diskann

Просмотреть файл

@ -87,15 +87,11 @@ namespace diskann {
}
template<typename T>
InMemorySearch<T>::InMemorySearch(
const std::string& baseFile,
const std::string& indexFile,
const std::string& tagsFile,
Metric m,
uint32_t num_threads,
uint32_t search_l
): BaseSearch(tagsFile) {
InMemorySearch<T>::InMemorySearch(const std::string& baseFile,
const std::string& indexFile,
const std::string& tagsFile, Metric m,
uint32_t num_threads, uint32_t search_l)
: BaseSearch(tagsFile) {
size_t dimensions, total_points = 0;
diskann::get_bin_metadata(baseFile, total_points, dimensions);
_index = std::unique_ptr<diskann::Index<T>>(
@ -136,9 +132,9 @@ namespace diskann {
}
template<typename T>
PQFlashSearch<T>::PQFlashSearch(const std::string& indexPrefix,
const unsigned num_nodes_to_cache,
const unsigned num_threads,
PQFlashSearch<T>::PQFlashSearch(const std::string& indexPrefix,
const unsigned num_nodes_to_cache,
const unsigned num_threads,
const std::string& tagsFile, Metric m)
: BaseSearch(tagsFile) {
#ifdef _WINDOWS
@ -162,7 +158,8 @@ namespace diskann {
int res = _index->load(num_threads, index_prefix_path.c_str());
if (res != 0) {
std::cerr << "Unable to load index. Status code: " << res << "." << std::endl;
std::cerr << "Unable to load index. Status code: " << res << "."
<< std::endl;
}
std::vector<uint32_t> node_list;
@ -174,13 +171,13 @@ namespace diskann {
}
template<typename T>
SearchResult PQFlashSearch<T>::search(const T* query,
SearchResult PQFlashSearch<T>::search(const T* query,
const unsigned int dimensions,
const unsigned int K,
const unsigned int Ls) {
_u64* indices_u64 = new _u64[K];
_u64* indices_u64 = new _u64[K];
unsigned* indices = new unsigned[K];
float* distances = new float[K];
float* distances = new float[K];
auto startTime = std::chrono::high_resolution_clock::now();
_index->cached_beam_search(query, K, Ls, indices_u64, distances, DEFAULT_W);

Просмотреть файл

@ -62,7 +62,7 @@ namespace diskann {
std::vector<size_t> pos(numsearchers, 0);
for (size_t k = 0; k < K; ++k) {
float best_distance = std::numeric_limits<float>::max();
float best_distance = std::numeric_limits<float>::max();
unsigned best_partition = 0;
for (size_t i = 0; i < numsearchers; ++i) {
@ -71,20 +71,23 @@ namespace diskann {
best_partition = i;
}
}
best_distances[k] = best_distance;
best_indices[k] = results[best_partition].get_indices()[pos[best_partition]];
best_partitions[k] = best_partition;
if (results[best_partition].tags_enabled())
best_tags[k] = results[best_partition].get_tags()[pos[best_partition]];
std::cout << best_partition << " " << pos[best_partition] << std::endl;
pos[best_partition]++;
best_distances[k] = best_distance;
best_indices[k] =
results[best_partition].get_indices()[pos[best_partition]];
best_partitions[k] = best_partition;
if (results[best_partition].tags_enabled())
best_tags[k] =
results[best_partition].get_tags()[pos[best_partition]];
std::cout << best_partition << " " << pos[best_partition] << std::endl;
pos[best_partition]++;
}
unsigned int total_time = 0;
for (size_t i = 0; i < numsearchers; ++i)
total_time += results[i].get_time();
diskann::SearchResult result = SearchResult(
K, total_time, best_indices, best_distances, best_tags, best_partitions);
diskann::SearchResult result =
SearchResult(K, total_time, best_indices, best_distances, best_tags,
best_partitions);
delete[] best_indices;
delete[] best_distances;
@ -101,8 +104,8 @@ namespace diskann {
void Server::handle_post(web::http::http_request message) {
message.extract_string(true)
.then([=](utility::string_t body) {
int64_t queryId = -1;
unsigned int K = 0;
int64_t queryId = -1;
unsigned int K = 0;
try {
T* queryVector = nullptr;
unsigned int dimensions = 0;
@ -113,7 +116,8 @@ namespace diskann {
std::vector<diskann::SearchResult> results;
for (auto& searcher : _multi_searcher)
results.push_back(searcher->search(queryVector, dimensions, (unsigned int) K, Ls));
results.push_back(searcher->search(queryVector, dimensions,
(unsigned int) K, Ls));
diskann::SearchResult result = aggregate_results(K, results);
diskann::aligned_free(queryVector);
web::json::value response = prepareResponse(queryId, K);
@ -139,13 +143,17 @@ namespace diskann {
return std::make_pair(web::http::status_codes::InternalError,
response);
} catch (...) {
std::cerr << "Uncaught exception while processing query: " << queryId;
std::cerr << "Uncaught exception while processing query: "
<< queryId;
web::json::value response = prepareResponse(queryId, K);
response[ERROR_MESSAGE_KEY] = web::json::value::string(UNKNOWN_ERROR);
return std::make_pair(web::http::status_codes::InternalError, response);
response[ERROR_MESSAGE_KEY] =
web::json::value::string(UNKNOWN_ERROR);
return std::make_pair(web::http::status_codes::InternalError,
response);
}
})
.then([=](std::pair<short unsigned int, web::json::value> response_status) {
.then([=](std::pair<short unsigned int, web::json::value>
response_status) {
try {
message.reply(response_status.first, response_status.second).wait();
} catch (const std::exception& ex) {
@ -180,7 +188,8 @@ namespace diskann {
if (k <= 0 || k > Ls) {
throw new std::invalid_argument(
"Num of expected NN (k) must be greater than zero and less than or equal to Ls.");
"Num of expected NN (k) must be greater than zero and less than or "
"equal to Ls.");
}
if (queryArr.size() == 0) {
throw new std::invalid_argument("Query vector has zero elements.");

Просмотреть файл

@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 14)
add_executable(build_memory_index build_memory_index.cpp)
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(build_stitched_index build_stitched_index.cpp)
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
add_executable(search_memory_index search_memory_index.cpp)
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)

Просмотреть файл

@ -13,11 +13,12 @@
namespace po = boost::program_options;
int main(int argc, char** argv) {
std::string data_type, dist_fn, data_path, index_path_prefix;
unsigned num_threads, R, L, disk_PQ, build_PQ;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
std::string data_type, dist_fn, data_path, index_path_prefix, label_file,
universal_label, label_type;
unsigned num_threads, R, L, disk_PQ, build_PQ, Lf, filter_threshold;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
po::options_description desc{"Arguments"};
try {
@ -60,9 +61,34 @@ int main(int argc, char** argv) {
desc.add_options()(
"build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
"Number of PQ bytes to build the index; 0 for full precision build");
desc.add_options()("use_opq", po::bool_switch()->default_value(false),
"Use Optimized Product Quantization (OPQ).");
desc.add_options()(
"label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index build ."
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
desc.add_options()(
"universal_label",
po::value<std::string>(&universal_label)->default_value(""),
"Universal label, Use only in conjuction with label file for filtered "
"index build. If a graph node has all the labels against it, we can "
"assign a special universal filter to the point instead of comma "
"separated filters for that point");
desc.add_options()("filtered_Lbuild,Lf",
po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
desc.add_options()(
"filter_threshold,F",
po::value<uint32_t>(&filter_threshold)->default_value(0),
"Threshold to break up the existing nodes to generate new graph "
"internally where each node has a maximum F labels.");
desc.add_options()(
"label_type",
po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@ -80,6 +106,11 @@ int main(int argc, char** argv) {
return -1;
}
bool use_filters = false;
if (label_file != "") {
use_filters = true;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
@ -116,21 +147,46 @@ int main(int argc, char** argv) {
std::string(std::to_string(build_PQ));
try {
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(),
index_path_prefix.c_str(),
params.c_str(), metric, use_opq);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric,
use_opq);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(data_path.c_str(),
index_path_prefix.c_str(),
params.c_str(), metric, use_opq);
else {
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
if (label_file != "" && label_type == "ushort") {
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else {
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
} else {
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, use_filters, label_file, universal_label,
filter_threshold, Lf);
else {
diskann::cerr << "Error. Unsupported data type" << std::endl;
return -1;
}
}
} catch (const std::exception& e) {
std::cout << std::string(e.what()) << std::endl;

Просмотреть файл

@ -20,44 +20,62 @@
namespace po = boost::program_options;
// Builds an in-memory index over the vectors in `data_path` and saves it
// under `save_path`.
//
// When `label_file` is empty a plain index is built; otherwise the labels are
// first converted to an integer-formatted file (save_path +
// "_label_formatted.txt", with the mapping written to save_path +
// "_labels_map.txt"), a filtered index is built from it, and the formatted
// label file is deleted again after the index has been saved.
//
// NOTE(review): this span is a flattened diff -- the pre-change and
// post-change versions of several lines (template header, parameter list,
// index construction, build call) both appear below.
template<typename T, typename TagT = uint32_t>
template<typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
int build_in_memory_index(const diskann::Metric& metric,
const std::string& data_path, const unsigned R,
const unsigned L, const float alpha,
const std::string& save_path,
const unsigned num_threads, const bool use_pq_build,
const size_t num_pq_bytes, const bool use_opq) {
const size_t num_pq_bytes, const bool use_opq,
const std::string& label_file,
const std::string& universal_label, const _u32 Lf) {
diskann::Parameters paras;
paras.Set<unsigned>("R", R);
paras.Set<unsigned>("L", L);
paras.Set<unsigned>("Lf", Lf);
paras.Set<unsigned>(
"C", 750); // maximum candidate set size during pruning procedure
paras.Set<float>("alpha", alpha);
paras.Set<bool>("saturate_graph", 0);
paras.Set<unsigned>("num_threads", num_threads);
// temporary, integer-formatted copy of the label file and its label map
std::string labels_file_to_use = save_path + "_label_formatted.txt";
std::string mem_labels_int_map_file = save_path + "_labels_map.txt";
_u64 data_num, data_dim;
diskann::get_bin_metadata(data_path, data_num, data_dim);
diskann::Index<T, TagT> index(metric, data_dim, data_num, false, false, false,
use_pq_build, num_pq_bytes, use_opq);
auto s = std::chrono::high_resolution_clock::now();
index.build(data_path.c_str(), data_num, paras);
diskann::Index<T, TagT, LabelT> index(metric, data_dim, data_num, false,
false, false, use_pq_build,
num_pq_bytes, use_opq);
auto s = std::chrono::high_resolution_clock::now();
if (label_file == "") {
index.build(data_path.c_str(), data_num, paras);
} else {
convert_labels_string_to_int(label_file, labels_file_to_use,
mem_labels_int_map_file, universal_label);
if (universal_label != "") {
// NOTE(review): std::stoul throws if universal_label is not numeric, and
// the result is narrowed to LabelT -- confirm the expected label format.
LabelT unv_label_as_num = std::stoul(universal_label);
index.set_universal_label(unv_label_as_num);
}
index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num,
paras);
}
std::chrono::duration<double> diff =
std::chrono::high_resolution_clock::now() - s;
std::cout << "Indexing time: " << diff.count() << "\n";
index.save(save_path.c_str());
// the formatted label file is only a build-time artifact
if (label_file != "")
std::remove(labels_file_to_use.c_str());
return 0;
}
int main(int argc, char** argv) {
std::string data_type, dist_fn, data_path, index_path_prefix;
unsigned num_threads, R, L, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
std::string data_type, dist_fn, data_path, index_path_prefix, label_file,
universal_label, label_type;
unsigned num_threads, R, L, Lf, build_PQ_bytes;
float alpha;
bool use_pq_build, use_opq;
po::options_description desc{"Arguments"};
try {
@ -89,12 +107,31 @@ int main(int argc, char** argv) {
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()(
"build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
"build_PQ_bytes",
po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
"Number of PQ bytes to build the index; 0 for full precision build");
desc.add_options()(
"use_opq", po::bool_switch()->default_value(false),
"Set true for OPQ compression while using PQ distance comparisons for "
"building the index, and false for PQ compression");
desc.add_options()(
"label_file", po::value<std::string>(&label_file)->default_value(""),
"Input label file in txt format for Filtered Index search. "
"The file should contain comma separated filters for each node "
"with each line corresponding to a graph node");
desc.add_options()(
"universal_label",
po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with labels_file");
desc.add_options()("FilteredLbuild,Lf",
po::value<uint32_t>(&Lf)->default_value(0),
"Build complexity for filtered points, higher value "
"results in better graphs");
desc.add_options()(
"label_type",
po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
@ -128,22 +165,48 @@ int main(int argc, char** argv) {
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L
<< " alpha: " << alpha << " #threads: " << num_threads
<< std::endl;
if (data_type == std::string("int8"))
return build_in_memory_index<int8_t>(metric, data_path, R, L, alpha,
index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq);
else if (data_type == std::string("uint8"))
return build_in_memory_index<uint8_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq);
else if (data_type == std::string("float"))
return build_in_memory_index<float>(metric, data_path, R, L, alpha,
index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq);
else {
std::cout << "Unsupported type. Use one of int8, uint8 or float."
<< std::endl;
return -1;
if (label_file != "" && label_type == "ushort") {
if (data_type == std::string("int8"))
return build_in_memory_index<int8_t, uint32_t, uint16_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else if (data_type == std::string("uint8"))
return build_in_memory_index<uint8_t, uint32_t, uint16_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else if (data_type == std::string("float"))
return build_in_memory_index<float, uint32_t, uint16_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else {
std::cout << "Unsupported type. Use one of int8, uint8 or float."
<< std::endl;
return -1;
}
} else {
if (data_type == std::string("int8"))
return build_in_memory_index<int8_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else if (data_type == std::string("uint8"))
return build_in_memory_index<uint8_t>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else if (data_type == std::string("float"))
return build_in_memory_index<float>(
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
Lf);
else {
std::cout << "Unsupported type. Use one of int8, uint8 or float."
<< std::endl;
return -1;
}
}
} catch (const std::exception& e) {
std::cout << std::string(e.what()) << std::endl;

Просмотреть файл

@ -0,0 +1,855 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <boost/program_options.hpp>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <omp.h>
#ifndef _WINDOWS
#include <sys/uio.h>
#endif
#include "index.h"
#include "memory_mapper.h"
#include "parameters.h"
#include "utils.h"
namespace po = boost::program_options;
// macros
// PBSTR supplies the fill characters of the progress bar and PBWIDTH is the
// bar width in characters; both are consumed by print_progress().
#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
#define PBWIDTH 60
// custom types (for readability)
typedef tsl::robin_set<std::string> label_set;
typedef std::string path;
// tuple aliases for returning multiple items from a function
// parse_label_file: (per-point label sets, label -> point count, all labels)
typedef std::tuple<std::vector<label_set>, tsl::robin_map<std::string, _u32>,
label_set>
parse_label_file_return_values;
// load_label_index: (adjacency lists, index file size in bytes)
typedef std::tuple<std::vector<std::vector<_u32>>, _u64>
load_label_index_return_values;
// stitch_label_indices: (stitched adjacency lists, expected file size in bytes)
typedef std::tuple<std::vector<std::vector<_u32>>, _u64>
stitch_indices_return_values;
/*
 * Draws a one-line text progress bar on stdout.
 *
 * percentage is the completed fraction (callers pass values in [0, 1]); it is
 * truncated to a whole percent for the numeric readout and scaled to a number
 * of filled cells out of PBWIDTH for the bar itself.
 */
inline void print_progress(double percentage) {
  const int percent_done = (int) (percentage * 100);
  const int filled_cells = (int) (percentage * PBWIDTH);
  const int blank_cells = PBWIDTH - filled_cells;
  // '\r' returns to the start of the line so each call overwrites the last bar
  printf("\r%3d%% [%.*s%*s]", percent_done, filled_cells, PBSTR, blank_cells,
         "");
  fflush(stdout);
}
/*
 * Inline function to generate a random integer in a range.
 *
 * Returns a uniformly distributed value in the CLOSED interval
 * [range_from, range_to] -- both endpoints can be returned, so callers that
 * want an index into a container of size n must pass n - 1 as range_to.
 *
 * The Mersenne Twister engine is seeded once per thread from
 * std::random_device and then reused, instead of constructing a new
 * random_device and reseeding a fresh engine on every call.
 *
 * Precondition: range_from <= range_to (required by
 * std::uniform_int_distribution).
 */
inline size_t random(size_t range_from, size_t range_to) {
  static thread_local std::mt19937 generator{std::random_device{}()};
  std::uniform_int_distribution<size_t> distr(range_from, range_to);
  return distr(generator);
}
/*
* function to handle command line parsing.
*
* Arguments are merely the inputs from the command line.
*/
void handle_args(int argc, char **argv, std::string &data_type,
path &input_data_path, path &final_index_path_prefix,
path &label_data_path, std::string &universal_label,
unsigned &num_threads, unsigned &R, unsigned &L,
unsigned &stitched_R, float &alpha) {
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("data_path",
po::value<path>(&input_data_path)->required(),
"Input data file in bin format");
desc.add_options()("index_path_prefix",
po::value<path>(&final_index_path_prefix)->required(),
"Path prefix for saving index file components");
desc.add_options()("max_degree,R",
po::value<uint32_t>(&R)->default_value(64),
"Maximum graph degree");
desc.add_options()(
"Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
"Build complexity, higher value results in better graphs");
desc.add_options()("stitched_R",
po::value<uint32_t>(&stitched_R)->default_value(100),
"Degree to prune final graph down to");
desc.add_options()(
"alpha", po::value<float>(&alpha)->default_value(1.2f),
"alpha controls density and diameter of graph, set 1 for sparse graph, "
"1.2 or 1.4 for denser graphs with lower diameter");
desc.add_options()(
"num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("label_file",
po::value<path>(&label_data_path)->default_value(""),
"Input label file in txt format if present");
desc.add_options()(
"universal_label",
po::value<std::string>(&universal_label)->default_value(""),
"If a point comes with the specified universal label (and only the "
"univ. "
"label), then the point is considered to have every possible label");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
exit(0);
}
po::notify(vm);
} catch (const std::exception &ex) {
std::cerr << ex.what() << '\n';
throw;
}
}
/*
 * Parses the label datafile, which has comma-separated labels on
 * each line. Line i corresponds to point id i.
 *
 * Returns three objects via std::tuple:
 * 1. vector: entry i is the set of labels of point id i
 * 2. map: key is label, value is number of points with the label
 * 3. the label universe as a set (the universal label itself is excluded)
 *
 * A point whose line contains the universal label is given every label in the
 * universe. Empty tokens (e.g. "a,,b") are ignored, so they can neither
 * create an empty-string label nor be mistaken for the universal label when
 * no universal label was configured.
 *
 * The per-label point counts are derived from the final per-point label sets,
 * so a label repeated on one line, or listed next to the universal label, is
 * counted once for that point.
 *
 * Prints an error and exits with status -1 if the file cannot be opened or if
 * any point ends up without a label.
 */
parse_label_file_return_values parse_label_file(path label_data_path,
                                                std::string universal_label) {
  std::ifstream label_data_stream(label_data_path);
  if (!label_data_stream.is_open()) {
    std::cerr << "Error: could not open label file " << label_data_path
              << std::endl;
    exit(-1);
  }

  // values to return
  std::vector<label_set>            point_ids_to_labels;
  tsl::robin_map<std::string, _u32> labels_to_number_of_points;
  label_set                         all_labels;

  std::vector<_u32> points_with_universal_label;
  std::string       line, token;

  while (std::getline(label_data_stream, line)) {
    std::istringstream current_labels_comma_separated(line);
    label_set          current_labels;
    // line i (0-based) describes point id i
    const _u32 point_id = (_u32) point_ids_to_labels.size();
    bool       has_universal_label = false;

    // parse comma separated labels
    while (std::getline(current_labels_comma_separated, token, ',')) {
      token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
      token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
      // an empty token carries no label
      if (token.empty())
        continue;

      if (token == universal_label) {
        has_universal_label = true;
      } else {
        all_labels.insert(token);
        current_labels.insert(token);
      }
    }

    if (current_labels.empty() && !has_universal_label) {
      std::cerr << "Error: " << point_id << " has no labels." << std::endl;
      exit(-1);
    }
    if (has_universal_label)
      points_with_universal_label.push_back(point_id);
    point_ids_to_labels.push_back(current_labels);
  }

  // for every point with universal label, set its label set to all labels
  // (done after the whole file is read, once the universe is complete)
  for (const auto &point_id : points_with_universal_label)
    point_ids_to_labels[point_id] = all_labels;

  // count points per label from the final sets so that the counts always
  // match the number of vectors later written for each label
  for (const auto &labels_of_point : point_ids_to_labels)
    for (const auto &lbl : labels_of_point)
      labels_to_number_of_points[lbl]++;

  std::cout << "Identified " << all_labels.size() << " distinct label(s) for "
            << point_ids_to_labels.size() << " points\n"
            << std::endl;

  return std::make_tuple(point_ids_to_labels, labels_to_number_of_points,
                         all_labels);
}
/*
 * For each label, generates a file containing all vectors that have said label.
 *
 * Utilizes POSIX functions mmap and writev in order to minimize memory
 * overhead, so we include an STL version as well.
 *
 * Each data file is saved under the following format:
 * input_data_path + "_" + label
 * and starts with two 32-bit values: number of points, dimension.
 *
 * Returns, for every label, the original point ids of its vectors in the
 * order in which they were written to that label's file.
 *
 * Throws std::runtime_error if the label data and the input file disagree on
 * the number of points, or if a file cannot be created or fully written.
 * (The bare `throw;` used previously has no active exception to rethrow and
 * therefore terminated the process.)
 */
template<typename T>
tsl::robin_map<std::string, std::vector<_u32>>
generate_label_specific_vector_files(
    path                              input_data_path,
    tsl::robin_map<std::string, _u32> labels_to_number_of_points,
    std::vector<label_set> point_ids_to_labels, label_set all_labels) {
  auto file_writing_timer = std::chrono::high_resolution_clock::now();
  diskann::MemoryMapper input_data(input_data_path);
  char                 *input_start = input_data.getBuf();

  _u32 number_of_points, dimension;
  std::memcpy(&number_of_points, input_start, sizeof(_u32));
  std::memcpy(&dimension, input_start + sizeof(_u32), sizeof(_u32));
  // size_t arithmetic: byte offsets into the input exceed 32 bits for large
  // data sets
  const size_t VECTOR_SIZE = (size_t) dimension * sizeof(T);
  const size_t METADATA = 2 * sizeof(_u32);
  if (number_of_points != point_ids_to_labels.size()) {
    std::cerr << "Error: number of points in labels file and data file differ."
              << std::endl;
    throw std::runtime_error(
        "number of points in labels file and data file differ");
  }

  // vectors own the iovec lists, so nothing leaks when an error is thrown
  tsl::robin_map<std::string, std::vector<iovec>> label_to_iovecs;
  tsl::robin_map<std::string, std::vector<_u32>>  label_id_to_orig_id;

  // setup iovec list for each label
  for (const auto &lbl : all_labels) {
    label_to_iovecs[lbl].reserve(labels_to_number_of_points[lbl]);
    label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]);
  }

  // each point added to corresponding per-label iovec list
  for (_u32 point_id = 0; point_id < number_of_points; point_id++) {
    iovec curr_iovec;
    curr_iovec.iov_base = input_start + METADATA + VECTOR_SIZE * point_id;
    curr_iovec.iov_len = VECTOR_SIZE;
    for (const auto &lbl : point_ids_to_labels[point_id]) {
      label_to_iovecs[lbl].push_back(curr_iovec);
      label_id_to_orig_id[lbl].push_back(point_id);
    }
  }

  // write each label iovec to resp. file
  for (const auto &lbl : all_labels) {
    path curr_label_input_data_path(input_data_path + "_" + lbl);
    std::vector<iovec> &label_iovecs = label_to_iovecs[lbl];
    const _u32          curr_num_pts = labels_to_number_of_points[lbl];
    // the count goes into the file header and is relied on by later steps,
    // so it must match the vectors actually collected
    if (label_iovecs.size() != curr_num_pts) {
      throw std::runtime_error("inconsistent point count for label " + lbl);
    }

    const int label_input_data_fd =
        open(curr_label_input_data_path.c_str(),
             O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t) 0644);
    if (label_input_data_fd == -1)
      throw std::runtime_error("failed to open " + curr_label_input_data_path);

    // write metadata
    _u32 metadata[2] = {curr_num_pts, dimension};
    if (write(label_input_data_fd, metadata, sizeof(metadata)) !=
        (ssize_t) sizeof(metadata)) {
      close(label_input_data_fd);
      throw std::runtime_error("failed to write metadata to " +
                               curr_label_input_data_path);
    }

    // limits on number of iovec structs per writev means we need to perform
    // multiple writevs
    size_t iovecs_written = 0;
    while (iovecs_written < label_iovecs.size()) {
      const size_t remaining = label_iovecs.size() - iovecs_written;
      const size_t batch =
          remaining > (size_t) IOV_MAX ? (size_t) IOV_MAX : remaining;
      const ssize_t return_value =
          writev(label_input_data_fd, label_iovecs.data() + iovecs_written,
                 (int) batch);
      // a short write would silently produce a truncated data file
      if (return_value == -1 ||
          (size_t) return_value != batch * VECTOR_SIZE) {
        close(label_input_data_fd);
        throw std::runtime_error("failed to write vectors to " +
                                 curr_label_input_data_path);
      }
      iovecs_written += batch;
    }

    close(label_input_data_fd);
    // release this label's list right away to keep peak memory low
    std::vector<iovec>().swap(label_iovecs);
  }

  std::chrono::duration<double> file_writing_time =
      std::chrono::high_resolution_clock::now() - file_writing_timer;
  std::cout << "generated " << all_labels.size()
            << " label-specific vector files for index building in time "
            << file_writing_time.count() << "\n"
            << std::endl;

  return label_id_to_orig_id;
}
// for use on systems without writev (i.e. Windows)
//
// Same contract as generate_label_specific_vector_files: for every label a
// file input_data_path + "_" + label is written (number of points, dimension,
// then the raw vectors) and the returned map gives, per label, the original
// point ids in file order.
//
// The input is opened in binary mode; a text-mode stream translates line
// endings and stops at Ctrl-Z on Windows, which corrupts the vector data.
//
// Throws std::runtime_error if the input cannot be opened or read, or if the
// label data and the input file disagree on point counts. Stream failures on
// the output files throw std::ios_base::failure.
template<typename T>
tsl::robin_map<std::string, std::vector<_u32>>
generate_label_specific_vector_files_compat(
    path                              input_data_path,
    tsl::robin_map<std::string, _u32> labels_to_number_of_points,
    std::vector<label_set> point_ids_to_labels, label_set all_labels) {
  auto file_writing_timer = std::chrono::high_resolution_clock::now();
  std::ifstream input_data_stream(input_data_path, std::ios::binary);
  if (!input_data_stream.is_open()) {
    throw std::runtime_error("failed to open " + input_data_path);
  }

  _u32 number_of_points, dimension;
  input_data_stream.read((char *) &number_of_points, sizeof(_u32));
  input_data_stream.read((char *) &dimension, sizeof(_u32));
  if (!input_data_stream) {
    throw std::runtime_error("failed to read metadata from " +
                             input_data_path);
  }
  // size_t arithmetic: per-label buffers can exceed 32 bits
  const size_t VECTOR_SIZE = (size_t) dimension * sizeof(T);
  if (number_of_points != point_ids_to_labels.size()) {
    std::cerr << "Error: number of points in labels file and data file differ."
              << std::endl;
    throw std::runtime_error(
        "number of points in labels file and data file differ");
  }

  // vectors own the per-label buffers, so nothing leaks when an error is
  // thrown
  tsl::robin_map<std::string, std::vector<char>> labels_to_vectors;
  tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id;

  for (const auto &lbl : all_labels) {
    const _u32 number_of_label_pts = labels_to_number_of_points[lbl];
    labels_to_vectors[lbl].reserve((size_t) number_of_label_pts * VECTOR_SIZE);
    label_id_to_orig_id[lbl].reserve(number_of_label_pts);
  }

  // one scratch buffer reused for every point
  std::vector<char> curr_vector(VECTOR_SIZE);
  for (_u32 point_id = 0; point_id < number_of_points; point_id++) {
    input_data_stream.read(curr_vector.data(), VECTOR_SIZE);
    if (!input_data_stream) {
      throw std::runtime_error("failed to read vector data from " +
                               input_data_path);
    }
    for (const auto &lbl : point_ids_to_labels[point_id]) {
      std::vector<char> &label_vectors = labels_to_vectors[lbl];
      label_vectors.insert(label_vectors.end(), curr_vector.begin(),
                           curr_vector.end());
      label_id_to_orig_id[lbl].push_back(point_id);
    }
  }

  for (const auto &lbl : all_labels) {
    path curr_label_input_data_path(input_data_path + "_" + lbl);
    std::vector<char> &label_vectors = labels_to_vectors[lbl];
    _u32               number_of_label_pts = labels_to_number_of_points[lbl];
    // the count goes into the file header and is relied on by later steps,
    // so it must match the vectors actually collected
    if (label_id_to_orig_id[lbl].size() != number_of_label_pts) {
      throw std::runtime_error("inconsistent point count for label " + lbl);
    }

    std::ofstream label_file_stream;
    label_file_stream.exceptions(std::ios::badbit | std::ios::failbit);
    label_file_stream.open(curr_label_input_data_path, std::ios_base::binary);
    label_file_stream.write((char *) &number_of_label_pts, sizeof(_u32));
    label_file_stream.write((char *) &dimension, sizeof(_u32));
    label_file_stream.write(label_vectors.data(),
                            (std::streamsize) label_vectors.size());
    label_file_stream.close();
    // release this label's buffer right away to keep peak memory low
    std::vector<char>().swap(label_vectors);
  }
  input_data_stream.close();

  std::chrono::duration<double> file_writing_time =
      std::chrono::high_resolution_clock::now() - file_writing_timer;
  std::cout << "generated " << all_labels.size()
            << " label-specific vector files for index building in time "
            << file_writing_time.count() << "\n"
            << std::endl;

  return label_id_to_orig_id;
}
/*
 * Builds one vanilla diskANN index per label from the per-label vector files
 * (input_data_path + "_" + label), using the passed in build parameters.
 *
 * Each index is saved under the following path:
 * final_index_path_prefix + "_" + label
 *
 * std::cout and diskann::cout are silenced while the indices are built so
 * that only the progress bar is shown.
 */
template<typename T>
void generate_label_indices(path input_data_path, path final_index_path_prefix,
                            label_set all_labels, unsigned R, unsigned L,
                            float alpha, unsigned num_threads) {
  diskann::Parameters build_params;
  build_params.Set<unsigned>("R", R);
  build_params.Set<unsigned>("L", L);
  build_params.Set<unsigned>("C", 750);
  build_params.Set<unsigned>("Lf", 0);
  build_params.Set<bool>("saturate_graph", false);
  build_params.Set<float>("alpha", alpha);
  build_params.Set<unsigned>("num_threads", num_threads);

  std::cout << "Generating indices per label..." << std::endl;

  double seconds_spent_building = 0.0;
  double fraction_done = 0.0;

  // putting the streams into a failed state suppresses all output until
  // clear() is called below
  std::cout.setstate(std::ios_base::failbit);
  diskann::cout.setstate(std::ios_base::failbit);

  // for each label, build an index on resp. points
  for (const auto &lbl : all_labels) {
    const path label_data_file(input_data_path + "_" + lbl);
    const path label_index_file(final_index_path_prefix + "_" + lbl);

    size_t points_in_label, dim;
    diskann::get_bin_metadata(label_data_file, points_in_label, dim);
    diskann::Index<T> index(diskann::Metric::L2, dim, points_in_label, false,
                            false);

    // only the build itself is timed, not construction or saving
    const auto build_start = std::chrono::high_resolution_clock::now();
    index.build(label_data_file.c_str(), points_in_label, build_params);
    const std::chrono::duration<double> build_duration =
        std::chrono::high_resolution_clock::now() - build_start;

    seconds_spent_building += build_duration.count();
    fraction_done += (1 / (double) all_labels.size());
    print_progress(fraction_done);

    index.save(label_index_file.c_str());
  }

  std::cout.clear();
  diskann::cout.clear();

  std::cout << "\nDone. Generated per-label indices in "
            << seconds_spent_building << " seconds\n"
            << std::endl;
}
/*
 * Manually loads a graph index in from a given file.
 *
 * File layout as read here: a header of
 *   u64 file size, u32 max observed degree, u32 entry point,
 *   u64 number of frozen points
 * followed by one record per node: u32 neighbor count, then that many u32
 * neighbor ids. Records are read until the byte count reaches the file size
 * stored in the header.
 *
 * label_number_of_points is the number of nodes the caller expects; the
 * returned graph has exactly that many adjacency lists.
 *
 * Returns both the graph index and the size of the file in bytes.
 *
 * Throws std::runtime_error if the file holds more nodes than expected or its
 * records do not add up to the size in the header, and
 * std::ios_base::failure on stream errors.
 */
load_label_index_return_values load_label_index(path label_index_path,
                                                _u32 label_number_of_points) {
  std::ifstream label_index_stream;
  label_index_stream.exceptions(std::ios::badbit | std::ios::failbit);
  label_index_stream.open(label_index_path, std::ios::binary);

  _u64         index_file_size, index_num_frozen_points;
  _u32         index_max_observed_degree, index_entry_point;
  const size_t INDEX_METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
  label_index_stream.read((char *) &index_file_size, sizeof(_u64));
  label_index_stream.read((char *) &index_max_observed_degree, sizeof(_u32));
  label_index_stream.read((char *) &index_entry_point, sizeof(_u32));
  label_index_stream.read((char *) &index_num_frozen_points, sizeof(_u64));
  size_t bytes_read = INDEX_METADATA;

  std::vector<std::vector<_u32>> label_index(label_number_of_points);
  _u32                           nodes_read = 0;
  while (bytes_read < index_file_size) {
    // guard against writing past the end of label_index
    if (nodes_read >= label_number_of_points) {
      throw std::runtime_error("index file " + label_index_path +
                               " contains more nodes than expected");
    }

    _u32 current_node_num_neighbors;
    label_index_stream.read((char *) &current_node_num_neighbors, sizeof(_u32));

    // read the neighbor ids straight into their final location
    std::vector<_u32> &current_node_neighbors = label_index[nodes_read];
    current_node_neighbors.resize(current_node_num_neighbors);
    label_index_stream.read((char *) current_node_neighbors.data(),
                            current_node_num_neighbors * sizeof(_u32));

    nodes_read++;
    bytes_read += sizeof(_u32) * ((size_t) current_node_num_neighbors + 1);
  }

  if (bytes_read != index_file_size) {
    throw std::runtime_error("index file " + label_index_path +
                             " does not match the size in its header");
  }

  // move the graph into the tuple instead of copying every adjacency list
  return std::make_tuple(std::move(label_index), index_file_size);
}
/*
 * Custom index save to write the in-memory index to disk.
 * Also writes required files for diskANN API, in this order -
 * 1. labels          (final_index_path_prefix + "_labels.txt", a copy of
 *                     label_data_path; redundant for static indices)
 * 2. data            (final_index_path_prefix + ".data", a copy of
 *                     input_data_path; redundant for static indices)
 * 3. labels_to_medoids (final_index_path_prefix + "_labels_to_medoids.txt")
 * 4. universal_label (final_index_path_prefix + "_universal_label.txt", only
 *                     when universal_label is non-empty)
 *
 * The graph itself is written to final_index_path_prefix: a header of
 * u64 final_index_size, u32 max observed degree, u32 entry point (0),
 * u64 number of frozen points (0), then per node a u32 neighbor count
 * followed by the u32 neighbor ids.
 *
 * Throws std::runtime_error if the number of bytes written differs from
 * final_index_size, and std::ios_base::failure on stream errors. (The bare
 * `throw;` used previously has no active exception to rethrow and therefore
 * terminated the process.)
 */
void save_full_index(path final_index_path_prefix, path input_data_path,
                     _u64                              final_index_size,
                     std::vector<std::vector<_u32>>    stitched_graph,
                     tsl::robin_map<std::string, _u32> entry_points,
                     std::string universal_label, path label_data_path) {
  // aux. file 1: labels
  auto          saving_index_timer = std::chrono::high_resolution_clock::now();
  std::ifstream original_label_data_stream;
  original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
  original_label_data_stream.open(label_data_path, std::ios::binary);
  std::ofstream new_label_data_stream;
  new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
  new_label_data_stream.open(final_index_path_prefix + "_labels.txt",
                             std::ios::binary);
  new_label_data_stream << original_label_data_stream.rdbuf();
  original_label_data_stream.close();
  new_label_data_stream.close();

  // aux. file 2: data
  std::ifstream original_input_data_stream;
  original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
  original_input_data_stream.open(input_data_path, std::ios::binary);
  std::ofstream new_input_data_stream;
  new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
  new_input_data_stream.open(final_index_path_prefix + ".data",
                             std::ios::binary);
  new_input_data_stream << original_input_data_stream.rdbuf();
  original_input_data_stream.close();
  new_input_data_stream.close();

  // aux. file 3: labels_to_medoids, one "label, point id" pair per line
  std::ofstream labels_to_medoids_writer;
  labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
  labels_to_medoids_writer.open(final_index_path_prefix +
                                "_labels_to_medoids.txt");
  for (auto iter : entry_points)
    labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
  labels_to_medoids_writer.close();

  // aux. file 4 (only if we're using a universal label)
  if (universal_label != "") {
    std::ofstream universal_label_writer;
    universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
    universal_label_writer.open(final_index_path_prefix +
                                "_universal_label.txt");
    universal_label_writer << universal_label << std::endl;
    universal_label_writer.close();
  }

  // main index
  _u64         index_num_frozen_points = 0, index_num_edges = 0;
  _u32         index_max_observed_degree = 0, index_entry_point = 0;
  const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
  for (const auto &point_neighbors : stitched_graph) {
    index_max_observed_degree =
        std::max(index_max_observed_degree, (_u32) point_neighbors.size());
  }

  std::ofstream stitched_graph_writer;
  stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
  stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);

  stitched_graph_writer.write((char *) &final_index_size, sizeof(_u64));
  stitched_graph_writer.write((char *) &index_max_observed_degree,
                              sizeof(_u32));
  stitched_graph_writer.write((char *) &index_entry_point, sizeof(_u32));
  stitched_graph_writer.write((char *) &index_num_frozen_points, sizeof(_u64));

  size_t bytes_written = METADATA;
  // iterate by reference and write each adjacency list with a single call
  // rather than copying the list and writing it one id at a time
  for (const auto &current_node_neighbors : stitched_graph) {
    const _u32 current_node_num_neighbors =
        (_u32) current_node_neighbors.size();
    stitched_graph_writer.write((const char *) &current_node_num_neighbors,
                                sizeof(_u32));
    if (current_node_num_neighbors > 0) {
      stitched_graph_writer.write(
          (const char *) current_node_neighbors.data(),
          (std::streamsize) (current_node_num_neighbors * sizeof(_u32)));
    }
    bytes_written += sizeof(_u32) * ((size_t) current_node_num_neighbors + 1);
    index_num_edges += current_node_num_neighbors;
  }

  if (bytes_written != final_index_size) {
    std::cerr << "Error: written bytes does not match allocated space"
              << std::endl;
    throw std::runtime_error("written bytes does not match allocated space");
  }

  stitched_graph_writer.close();

  std::chrono::duration<double> saving_index_time =
      std::chrono::high_resolution_clock::now() - saving_index_timer;
  std::cout << "Stitched graph written in " << saving_index_time.count()
            << " seconds" << std::endl;
  std::cout << "Stitched graph average degree: "
            << ((float) index_num_edges) / ((float) (stitched_graph.size()))
            << std::endl;
  std::cout << "Stitched graph max degree: " << index_max_observed_degree
            << std::endl
            << std::endl;
}
/*
 * Unions the per-label graph indices together via the following policy:
 * - any two nodes can only have at most one edge between them -
 *
 * For every label the index saved at final_index_path_prefix + "_" + label is
 * loaded, its node ids are translated back to original point ids through
 * label_id_to_orig_id_map, and every edge not already present is appended to
 * the stitched graph. A uniformly random node of each label index is recorded
 * (as an original point id) in label_entry_points.
 *
 * Returns the "stitched" graph and its expected file size.
 *
 * Throws std::runtime_error if a label index contains no nodes.
 */
template<typename T>
stitch_indices_return_values stitch_label_indices(
    path final_index_path_prefix, _u32 total_number_of_points,
    label_set                          all_labels,
    tsl::robin_map<std::string, _u32>  labels_to_number_of_points,
    tsl::robin_map<std::string, _u32> &label_entry_points,
    tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map) {
  size_t                         final_index_size = 0;
  std::vector<std::vector<_u32>> stitched_graph(total_number_of_points);

  auto stitching_index_timer = std::chrono::high_resolution_clock::now();
  for (const auto &lbl : all_labels) {
    path curr_label_index_path(final_index_path_prefix + "_" + lbl);
    std::vector<std::vector<_u32>> curr_label_index;
    _u64                           curr_label_index_size;
    _u32                           curr_label_entry_point;

    std::tie(curr_label_index, curr_label_index_size) = load_label_index(
        curr_label_index_path, labels_to_number_of_points[lbl]);
    if (curr_label_index.empty()) {
      throw std::runtime_error("label index " + curr_label_index_path +
                               " contains no nodes");
    }

    // random() draws from a closed interval, so the upper bound must be the
    // last valid node id; passing size() could index one past the end below
    curr_label_entry_point = random(0, curr_label_index.size() - 1);
    label_entry_points[lbl] =
        label_id_to_orig_id_map[lbl][curr_label_entry_point];

    for (_u32 node_point = 0; node_point < curr_label_index.size();
         node_point++) {
      _u32 original_point_id = label_id_to_orig_id_map[lbl][node_point];
      // reference, not a copy: the list is only searched and appended to
      std::vector<_u32> &curr_point_neighbors =
          stitched_graph[original_point_id];
      for (auto &node_neighbor : curr_label_index[node_point]) {
        _u32 original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
        if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(),
                      original_neighbor_id) == curr_point_neighbors.end()) {
          curr_point_neighbors.push_back(original_neighbor_id);
          final_index_size += sizeof(_u32);
        }
      }
    }
  }

  // one u32 neighbor count per node plus the fixed-size header
  const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
  final_index_size += (total_number_of_points * sizeof(_u32) + METADATA);

  std::chrono::duration<double> stitching_index_time =
      std::chrono::high_resolution_clock::now() - stitching_index_timer;
  std::cout << "stitched graph generated in memory in "
            << stitching_index_time.count() << " seconds" << std::endl;

  return std::make_tuple(stitched_graph, final_index_size);
}
/*
 * Loads the stitched graph saved at full_index_path_prefix into a
 * diskann::Index, calls prune_all_nbrs on it with a target degree of
 * stitched_R, and saves the result to final_index_path_prefix.
 *
 * Only the pruned index is written by this function.
 *
 * std::cout and diskann::cout have their stream buffers detached for the
 * duration of the load/prune/save and re-attached afterwards.
 *
 * NOTE(review): stitched_graph, label_entry_points, universal_label and
 * label_data_path are not used in the body.
 */
template<typename T>
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix,
path input_data_path,
std::vector<std::vector<_u32>> stitched_graph,
unsigned stitched_R,
tsl::robin_map<std::string, _u32> label_entry_points,
std::string universal_label, path label_data_path,
unsigned num_threads) {
size_t dimension, number_of_label_points;
// detach both stream buffers to silence output; the old buffers are kept so
// they can be restored below
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
auto std_cout_buffer = std::cout.rdbuf(nullptr);
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
diskann::Index<T> index(diskann::Metric::L2, dimension,
number_of_label_points, false, false);
// not searching this index
// NOTE(review): this comment used to say search_l is set to 0, but the value
// passed as the last argument is 1 -- confirm which is intended
index.load(full_index_path_prefix.c_str(), num_threads, 1);
diskann::Parameters paras;
paras.Set<unsigned>("R", stitched_R);
paras.Set<unsigned>(
"C", 750); // maximum candidate set size during pruning procedure
paras.Set<float>("alpha", 1.2);
paras.Set<bool>("saturate_graph", 1);
// std::cout has no buffer attached at this point, so this message is not
// displayed
std::cout << "parsing labels" << std::endl;
index.prune_all_nbrs(paras);
index.save((final_index_path_prefix).c_str());
// re-attach the original stream buffers
diskann::cout.rdbuf(diskann_cout_buffer);
std::cout.rdbuf(std_cout_buffer);
std::chrono::duration<double> pruning_index_time =
std::chrono::high_resolution_clock::now() - pruning_index_timer;
std::cout << "pruning performed in " << pruning_index_time.count()
<< " seconds\n"
<< std::endl;
}
/*
 * Delete all temporary artifacts.
 * In the process of creating the stitched index, some temporary artifacts are
 * created:
 * 1. the separate bin files for each labels' points
 * 2. the separate diskANN indices built for each label
 * 3. the '.data' file created while generating the indices
 *
 * Every artifact is attempted even if an earlier deletion fails, so a single
 * failure does not leave the remaining temporary files behind. Each failure
 * is reported on std::cerr, and a std::ios_base::failure (a std::exception)
 * is thrown once all labels have been processed if anything could not be
 * deleted.
 */
void clean_up_artifacts(path input_data_path, path final_index_path_prefix,
                        label_set all_labels) {
  size_t failed_deletions = 0;

  for (const auto &lbl : all_labels) {
    path curr_label_input_data_path(input_data_path + "_" + lbl);
    path curr_label_index_path(final_index_path_prefix + "_" + lbl);
    path curr_label_index_path_data(curr_label_index_path + ".data");

    // Same deletion order as before: index, per-label data file, '.data' file.
    const path *artifacts[] = {&curr_label_index_path,
                               &curr_label_input_data_path,
                               &curr_label_index_path_data};
    for (const path *artifact : artifacts) {
      if (std::remove(artifact->c_str()) != 0) {
        std::cerr << "Failed to delete temporary artifact "
                  << artifact->c_str() << std::endl;
        ++failed_deletions;
      }
    }
  }

  // A bare `throw;` with no exception in flight calls std::terminate, so
  // raise a real exception that carries a message instead.
  if (failed_deletions != 0) {
    throw std::ios_base::failure(
        "clean_up_artifacts: failed to delete " +
        std::to_string(failed_deletions) + " temporary artifact(s)");
  }
}
/*
 * Entry point: builds a stitched filtered index.
 *
 * Steps: parse the command line, convert and parse the label file, write one
 * data file per label, build one index per label, stitch the per-label graphs
 * together, save the full stitched graph, prune it and save the pruned index,
 * then delete the temporary per-label artifacts.
 *
 * Returns 0 on success and -1 when data_type is not one of
 * uint8 / int8 / float.
 */
int main(int argc, char **argv) {
  // 1. handle cmdline inputs
  std::string data_type;
  path        input_data_path, final_index_path_prefix, label_data_path;
  std::string universal_label;
  unsigned    num_threads, R, L, stitched_R;
  float       alpha;

  auto index_timer = std::chrono::high_resolution_clock::now();
  handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix,
              label_data_path, universal_label, num_threads, R, L, stitched_R,
              alpha);

  // Reject an unsupported element type before any file is written. The
  // previous bare `throw;` statements had no exception to rethrow and so
  // ended the process through std::terminate without any diagnostic.
  const bool is_uint8 = (data_type == "uint8");
  const bool is_int8 = (data_type == "int8");
  const bool is_float = (data_type == "float");
  if (!(is_uint8 || is_int8 || is_float)) {
    std::cerr << "Unsupported data type " << data_type
              << ". Use float or int8 or uint8" << std::endl;
    return -1;
  }

  path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
  path labels_map_file = final_index_path_prefix + "_labels_map.txt";

  convert_labels_string_to_int(label_data_path, labels_file_to_use,
                               labels_map_file, universal_label);

  // 2. parse label file and create necessary data structures
  std::vector<label_set>            point_ids_to_labels;
  tsl::robin_map<std::string, _u32> labels_to_number_of_points;
  label_set                         all_labels;

  std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
      parse_label_file(labels_file_to_use, universal_label);

  // 3. for each label, make a separate data file
  tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map;
  _u32 total_number_of_points =
      static_cast<_u32>(point_ids_to_labels.size());

#ifndef _WINDOWS
  if (is_uint8)
    label_id_to_orig_id_map = generate_label_specific_vector_files<uint8_t>(
        input_data_path, labels_to_number_of_points, point_ids_to_labels,
        all_labels);
  else if (is_int8)
    label_id_to_orig_id_map = generate_label_specific_vector_files<int8_t>(
        input_data_path, labels_to_number_of_points, point_ids_to_labels,
        all_labels);
  else  // float, validated above
    label_id_to_orig_id_map = generate_label_specific_vector_files<float>(
        input_data_path, labels_to_number_of_points, point_ids_to_labels,
        all_labels);
#else
  if (is_uint8)
    label_id_to_orig_id_map =
        generate_label_specific_vector_files_compat<uint8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels,
            all_labels);
  else if (is_int8)
    label_id_to_orig_id_map =
        generate_label_specific_vector_files_compat<int8_t>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels,
            all_labels);
  else  // float, validated above
    label_id_to_orig_id_map =
        generate_label_specific_vector_files_compat<float>(
            input_data_path, labels_to_number_of_points, point_ids_to_labels,
            all_labels);
#endif

  // 4. for each created data file, create a vanilla diskANN index
  if (is_uint8)
    generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix,
                                    all_labels, R, L, alpha, num_threads);
  else if (is_int8)
    generate_label_indices<int8_t>(input_data_path, final_index_path_prefix,
                                   all_labels, R, L, alpha, num_threads);
  else  // float, validated above
    generate_label_indices<float>(input_data_path, final_index_path_prefix,
                                  all_labels, R, L, alpha, num_threads);

  // 5. "stitch" the indices together
  std::vector<std::vector<_u32>>    stitched_graph;
  tsl::robin_map<std::string, _u32> label_entry_points;
  _u64                              stitched_graph_size;

  if (is_uint8)
    std::tie(stitched_graph, stitched_graph_size) =
        stitch_label_indices<uint8_t>(
            final_index_path_prefix, total_number_of_points, all_labels,
            labels_to_number_of_points, label_entry_points,
            label_id_to_orig_id_map);
  else if (is_int8)
    std::tie(stitched_graph, stitched_graph_size) =
        stitch_label_indices<int8_t>(
            final_index_path_prefix, total_number_of_points, all_labels,
            labels_to_number_of_points, label_entry_points,
            label_id_to_orig_id_map);
  else  // float, validated above
    std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices<float>(
        final_index_path_prefix, total_number_of_points, all_labels,
        labels_to_number_of_points, label_entry_points,
        label_id_to_orig_id_map);

  path full_index_path_prefix = final_index_path_prefix + "_full";

  // 5a. save the stitched graph to disk
  save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size,
                  stitched_graph, label_entry_points, universal_label,
                  labels_file_to_use);

  // 6. run a prune on the stitched index, and save to disk
  if (is_uint8)
    prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix,
                            input_data_path, stitched_graph, stitched_R,
                            label_entry_points, universal_label,
                            labels_file_to_use, num_threads);
  else if (is_int8)
    prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix,
                           input_data_path, stitched_graph, stitched_R,
                           label_entry_points, universal_label,
                           labels_file_to_use, num_threads);
  else  // float, validated above
    prune_and_save<float>(final_index_path_prefix, full_index_path_prefix,
                          input_data_path, stitched_graph, stitched_R,
                          label_entry_points, universal_label,
                          labels_file_to_use, num_threads);

  std::chrono::duration<double> index_time =
      std::chrono::high_resolution_clock::now() - index_timer;
  std::cout << "pruned/stitched graph generated in " << index_time.count()
            << " seconds" << std::endl;

  // 7. remove the temporary per-label files now that the final index exists
  clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
  return 0;
}

Просмотреть файл

@ -47,7 +47,7 @@ void print_stats(std::string category, std::vector<float> percentiles,
diskann::cout << std::endl;
}
template<typename T>
template<typename T, typename LabelT = uint32_t>
int search_disk_index(diskann::Metric& metric,
const std::string& index_path_prefix,
const std::string& query_file, std::string& gt_file,
@ -99,8 +99,8 @@ int search_disk_index(diskann::Metric& metric,
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T>> _pFlashIndex(
new diskann::PQFlashIndex<T>(reader, metric));
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
@ -202,8 +202,8 @@ int search_disk_index(diskann::Metric& metric,
std::vector<_u64> indices;
std::vector<float> distances;
_u32 res_count = _pFlashIndex->range_search(
query + (i * query_aligned_dim), search_range, L, max_list_size,
indices, distances, optimized_beamwidth, stats + i);
query + (i * query_aligned_dim), search_range, L, max_list_size,
indices, distances, optimized_beamwidth, stats + i);
query_result_ids[test_id][i].reserve(res_count);
query_result_ids[test_id][i].resize(res_count);
for (_u32 idx = 0; idx < res_count; idx++)

Просмотреть файл

@ -12,7 +12,6 @@
#include <codecvt>
#include <boost/program_options.hpp>
#include <cpprest/http_client.h>
#include <restapi/common.h>
@ -65,47 +64,40 @@ void query_loop(const std::string& ip_addr_port, const std::string& query_file,
}
int main(int argc, char* argv[]) {
std::string data_type, query_file, address;
uint32_t num_queries;
uint32_t l_search, k_value;
std::string data_type, query_file, address;
uint32_t num_queries;
uint32_t l_search, k_value;
po::options_description desc{ "Arguments" };
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address",
po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("query_file",
po::value<std::string>(&query_file)->required(),
"File containing the queries to search");
desc.add_options()(
"num_queries,Q",
po::value<uint32_t>(&num_queries)->required(),
"Number of queries to search");
desc.add_options()(
"l_search",
po::value<uint32_t>(&l_search)->required(),
"Value of L");
desc.add_options()(
"k_value,K",
po::value<uint32_t>(&k_value)->default_value(10),
"Value of K (default 10)");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("query_file",
po::value<std::string>(&query_file)->required(),
"File containing the queries to search");
desc.add_options()("num_queries,Q",
po::value<uint32_t>(&num_queries)->required(),
"Number of queries to search");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(),
"Value of L");
desc.add_options()("k_value,K",
po::value<uint32_t>(&k_value)->default_value(10),
"Value of K (default 10)");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
po::notify(vm);
} catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
if (data_type == std::string("float")) {
query_loop<float>(address, query_file, num_queries, l_search, k_value);

Просмотреть файл

@ -9,7 +9,6 @@
#include <codecvt>
#include <boost/program_options.hpp>
#include <restapi/server.h>
using namespace diskann;
@ -37,99 +36,96 @@ void teardown(const utility::string_t& address) {
}
int main(int argc, char* argv[]) {
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
uint32_t num_threads;
uint32_t l_search;
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
uint32_t num_threads;
uint32_t l_search;
po::options_description desc{ "Arguments" };
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address",
po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("data_file",
po::value<std::string>(&data_file)->required(),
"File containing the data found in the index");
desc.add_options()("index_path_prefix",
po::value<std::string>(&index_file)->required(),
"Path prefix for saving index file components");
desc.add_options()(
"num_threads,T",
po::value<uint32_t>(&num_threads)->required(),
"Number of threads used for building index");
desc.add_options()(
"l_search",
po::value<uint32_t>(&l_search)->required(),
"Value of L");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file",
po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else {
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("data_file",
po::value<std::string>(&data_file)->required(),
"File containing the data found in the index");
desc.add_options()("index_path_prefix",
po::value<std::string>(&index_file)->required(),
"Path prefix for saving index file components");
desc.add_options()("num_threads,T",
po::value<uint32_t>(&num_threads)->required(),
"Number of threads used for building index");
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(),
"Value of L");
desc.add_options()("dist_fn",
po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()(
"tags_file",
po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
} catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else {
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
}
if (data_type == std::string("float")) {
auto searcher =
if (data_type == std::string("float")) {
auto searcher =
std::unique_ptr<diskann::BaseSearch>(new diskann::InMemorySearch<float>(
data_file, index_file, tags_file, metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
} else if (data_type == std::string("int8")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file,
metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
} else if (data_type == std::string("uint8")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file,
metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
} else {
std::cerr << "Unsupported data type " << argv[2] << std::endl;
}
g_inMemorySearch.push_back(std::move(searcher));
} else if (data_type == std::string("int8")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file,
metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
} else if (data_type == std::string("uint8")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file,
metric, num_threads, l_search));
g_inMemorySearch.push_back(std::move(searcher));
} else {
std::cerr << "Unsupported data type " << argv[2] << std::endl;
}
while (1) {
try {
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit") {
teardown(address);
g_httpServer->close().wait();
exit(0);
}
} catch (const std::exception& ex) {
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
} catch (...) {
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
while (1) {
try {
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit") {
teardown(address);
g_httpServer->close().wait();
exit(0);
}
} catch (const std::exception& ex) {
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
} catch (...) {
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

Просмотреть файл

@ -37,134 +37,135 @@ void teardown(const utility::string_t& address) {
}
int main(int argc, char* argv[]) {
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{ "Arguments" };
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("address",
po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("index_prefix_paths",
po::value<std::string>(&index_prefix_paths)->required(),
"Path prefix for loading index file components");
desc.add_options()(
"num_nodes_to_cache",
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()(
"num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file",
po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("address", po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("index_prefix_paths",
po::value<std::string>(&index_prefix_paths)->required(),
"Path prefix for loading index file components");
desc.add_options()(
"num_nodes_to_cache",
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
"Number of nodes to cache during search");
desc.add_options()(
"num_threads,T",
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn",
po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()(
"tags_file",
po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
}
catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else {
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
} catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
}
std::vector<std::pair<std::string, std::string>> index_tag_paths;
std::ifstream index_in(index_prefix_paths);
if (!index_in.is_open()) {
std::cerr << "Could not open " << index_prefix_paths << std::endl;
exit(-1);
}
std::ifstream tags_in(tags_file);
if (!tags_in.is_open()) {
std::cerr << "Could not open " << tags_file << std::endl;
exit(-1);
}
std::string prefix, tagfile;
while (std::getline(index_in, prefix)) {
if (std::getline(tags_in, tagfile)) {
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
} else {
std::cerr << "The number of tags specified does not match the number of "
"indices specified" << std::endl;
exit(-1);
}
}
index_in.close();
tags_in.close();
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
else {
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
}
if (data_type == std::string("float")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(index_tag.first.c_str(),
num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
} else if (data_type == std::string("int8")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<int8_t>(index_tag.first.c_str(),
num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
} else if (data_type == std::string("uint8")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(index_tag.first.c_str(),
num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
std::vector<std::pair<std::string, std::string>> index_tag_paths;
std::ifstream index_in(index_prefix_paths);
if (!index_in.is_open()) {
std::cerr << "Could not open " << index_prefix_paths << std::endl;
exit(-1);
}
std::ifstream tags_in(tags_file);
if (!tags_in.is_open()) {
std::cerr << "Could not open " << tags_file << std::endl;
exit(-1);
}
std::string prefix, tagfile;
while (std::getline(index_in, prefix)) {
if (std::getline(tags_in, tagfile)) {
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
} else {
std::cerr << "Unsupported data type " << data_type << std::endl;
exit(-1);
std::cerr << "The number of tags specified does not match the number of "
"indices specified"
<< std::endl;
exit(-1);
}
}
index_in.close();
tags_in.close();
if (data_type == std::string("float")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(index_tag.first.c_str(),
num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
while (1) {
try {
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit") {
teardown(address);
g_httpServer->close().wait();
exit(0);
}
} catch (const std::exception& ex) {
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
} catch (...) {
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
} else if (data_type == std::string("int8")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<int8_t>(index_tag.first.c_str(),
num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
} else if (data_type == std::string("uint8")) {
for (auto& index_tag : index_tag_paths) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(
index_tag.first.c_str(), num_nodes_to_cache, num_threads,
index_tag.second.c_str(), metric));
g_ssdSearch.push_back(std::move(searcher));
}
} else {
std::cerr << "Unsupported data type " << data_type << std::endl;
exit(-1);
}
while (1) {
try {
setup(address, data_type);
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
std::string line;
std::getline(std::cin, line);
if (line == "exit") {
teardown(address);
g_httpServer->close().wait();
exit(0);
}
} catch (const std::exception& ex) {
std::cerr << "Exception occurred: " << ex.what() << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
} catch (...) {
std::cerr << "Unknown exception occurreed" << std::endl;
std::cerr << "Restarting HTTP server";
teardown(address);
}
}
}

Просмотреть файл

@ -10,8 +10,6 @@
#include <boost/program_options.hpp>
#include <omp.h>
#include <restapi/server.h>
using namespace diskann;
@ -22,7 +20,7 @@ std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
void setup(const utility::string_t& address, const std::string& typestring) {
web::http::uri_builder uriBldr(address);
auto uri = uriBldr.to_uri();
auto uri = uriBldr.to_uri();
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
@ -40,21 +38,20 @@ void teardown(const utility::string_t& address) {
int main(int argc, char* argv[]) {
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
uint32_t num_nodes_to_cache;
uint32_t num_threads;
po::options_description desc{ "Arguments" };
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("data_type",
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address",
po::value<std::string>(&address)->required(),
"Web server address");
po::value<std::string>(&data_type)->required(),
"data type <int8/uint8/float>");
desc.add_options()("address", po::value<std::string>(&address)->required(),
"Web server address");
desc.add_options()("index_path_prefix",
po::value<std::string>(&index_path_prefix)->required(),
"Path prefix for loading index file components");
po::value<std::string>(&index_path_prefix)->required(),
"Path prefix for loading index file components");
desc.add_options()(
"num_nodes_to_cache",
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
@ -64,52 +61,52 @@ int main(int argc, char* argv[]) {
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
"Number of threads used for building index (defaults to "
"omp_get_num_procs())");
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()("tags_file",
desc.add_options()("dist_fn",
po::value<std::string>(&dist_fn)->default_value("l2"),
"distance function <l2/mips>");
desc.add_options()(
"tags_file",
po::value<std::string>(&tags_file)->default_value(std::string()),
"Tags file location");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
std::cout << desc;
return 0;
}
po::notify(vm);
} catch (const std::exception& ex) {
std::cerr << ex.what() << std::endl;
return -1;
std::cerr << ex.what() << std::endl;
return -1;
}
diskann::Metric metric;
if (dist_fn == std::string("l2"))
metric = diskann::Metric::L2;
metric = diskann::Metric::L2;
else if (dist_fn == std::string("mips"))
metric = diskann::Metric::INNER_PRODUCT;
metric = diskann::Metric::INNER_PRODUCT;
else {
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
std::cout << "Error. Only l2 and mips distance functions are supported"
<< std::endl;
return -1;
}
if (data_type == std::string("float")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<float>(
index_path_prefix, num_nodes_to_cache, num_threads,
tags_file, metric));
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache,
num_threads, tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
} else if (data_type == std::string("int8")) {
auto searcher =
std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
index_path_prefix, num_nodes_to_cache, num_threads,
tags_file, metric));
index_path_prefix, num_nodes_to_cache, num_threads, tags_file,
metric));
g_ssdSearch.push_back(std::move(searcher));
} else if (data_type == std::string("uint8")) {
auto searcher = std::unique_ptr<diskann::BaseSearch>(
new diskann::PQFlashSearch<uint8_t>(
index_path_prefix, num_nodes_to_cache, num_threads,
tags_file, metric));
new diskann::PQFlashSearch<uint8_t>(index_path_prefix,
num_nodes_to_cache, num_threads,
tags_file, metric));
g_ssdSearch.push_back(std::move(searcher));
} else {
std::cerr << "Unsupported data type " << argv[2] << std::endl;

Просмотреть файл

@ -44,14 +44,15 @@ void print_stats(std::string category, std::vector<float> percentiles,
diskann::cout << std::endl;
}
template<typename T>
template<typename T, typename LabelT = uint32_t>
int search_disk_index(
diskann::Metric& metric, const std::string& index_path_prefix,
const std::string& result_output_prefix, const std::string& query_file,
std::string& gt_file, const unsigned num_threads, const unsigned recall_at,
const unsigned beamwidth, const unsigned num_nodes_to_cache,
const _u32 search_io_limit, const std::vector<unsigned>& Lvec,
const bool use_reorder_data, const float fail_if_recall_below) {
const float fail_if_recall_below, const bool use_reorder_data = false,
const std::string& filter_label = "") {
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
@ -62,6 +63,10 @@ int search_disk_index(
else
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
bool filtered_search = false;
if (filter_label != "")
filtered_search = true;
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
// load query bin
@ -95,8 +100,8 @@ int search_disk_index(
reader.reset(new LinuxAlignedFileReader());
#endif
std::unique_ptr<diskann::PQFlashIndex<T>> _pFlashIndex(
new diskann::PQFlashIndex<T>(reader, metric));
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
@ -207,11 +212,22 @@ int search_disk_index(
#pragma omp parallel for schedule(dynamic, 1)
for (_s64 i = 0; i < (int64_t) query_num; i++) {
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, search_io_limit, use_reorder_data, stats + i);
if (!filtered_search) {
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, use_reorder_data, stats + i);
} else {
LabelT label_for_search =
_pFlashIndex->get_converted_label(filter_label);
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, true, label_for_search, use_reorder_data,
stats + i);
}
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e - s;
@ -282,7 +298,7 @@ int search_disk_index(
int main(int argc, char** argv) {
std::string data_type, dist_fn, index_path_prefix, result_path_prefix,
query_file, gt_file;
query_file, gt_file, filter_label, label_type;
unsigned num_threads, K, W, num_nodes_to_cache, search_io_limit;
std::vector<unsigned> Lvec;
bool use_reorder_data = false;
@ -333,6 +349,15 @@ int main(int argc, char** argv) {
po::bool_switch()->default_value(false),
"Include full precision data in the index. Use only in "
"conjuction with compressed data on SSD.");
desc.add_options()(
"filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
"Filter Label for Filtered Search");
desc.add_options()(
"label_type",
po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
desc.add_options()(
"fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
@ -382,29 +407,52 @@ int main(int argc, char** argv) {
}
try {
if (data_type == std::string("float"))
return search_disk_index<float>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
use_reorder_data, fail_if_recall_below);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
use_reorder_data, fail_if_recall_below);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
use_reorder_data, fail_if_recall_below);
else {
std::cerr << "Unsupported data type. Use float or int8 or uint8"
<< std::endl;
return -1;
if (filter_label != "" && label_type == "ushort") {
if (data_type == std::string("float"))
return search_disk_index<float, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else {
std::cerr << "Unsupported data type. Use float or int8 or uint8"
<< std::endl;
return -1;
}
} else {
if (data_type == std::string("float"))
return search_disk_index<float>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, use_reorder_data, filter_label);
else {
std::cerr << "Unsupported data type. Use float or int8 or uint8"
<< std::endl;
return -1;
}
}
} catch (const std::exception& e) {
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Index search failed." << std::endl;
return -1;
}
}
}

Просмотреть файл

@ -23,7 +23,7 @@
namespace po = boost::program_options;
template<typename T>
template<typename T, typename LabelT = uint32_t>
int search_memory_index(diskann::Metric& metric, const std::string& index_path,
const std::string& result_path_prefix,
const std::string& query_file,
@ -32,7 +32,8 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
const bool print_all_recalls,
const std::vector<unsigned>& Lvec, const bool dynamic,
const bool tags, const bool show_qps_per_thread,
const float fail_if_recall_below) {
const std::string& filter_label,
const float fail_if_recall_below) {
// Load the query file
T* query = nullptr;
unsigned* gt_ids = nullptr;
@ -54,8 +55,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
diskann::cout << " Truthset file " << truthset_file
<< " not found. Not computing recall." << std::endl;
}
bool filtered_search = false;
if (filter_label != "") {
filtered_search = true;
}
using TagT = uint32_t;
diskann::Index<T, TagT> index(metric, query_dim, 0, dynamic, tags);
diskann::Index<T, TagT, LabelT> index(metric, query_dim, 0, dynamic, tags);
std::cout << "Index class instantiated" << std::endl;
index.load(index_path.c_str(), num_threads,
*(std::max_element(Lvec.begin(), Lvec.end())));
@ -117,6 +124,7 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
}
query_result_ids[test_id].resize(recall_at * query_num);
query_result_dists[test_id].resize(recall_at * query_num);
std::vector<T*> res = std::vector<T*>();
auto s = std::chrono::high_resolution_clock::now();
@ -124,7 +132,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
#pragma omp parallel for schedule(dynamic, 1)
for (int64_t i = 0; i < (int64_t) query_num; i++) {
auto qs = std::chrono::high_resolution_clock::now();
if (metric == diskann::FAST_L2) {
if (filtered_search) {
LabelT filter_label_as_num = index.get_converted_label(filter_label);
auto retval = index.search_with_filters(
query + i * query_aligned_dim, filter_label_as_num, recall_at, L,
query_result_ids[test_id].data() + i * recall_at,
query_result_dists[test_id].data() + i * recall_at);
cmp_stats[i] = retval.second;
} else if (metric == diskann::FAST_L2) {
index.search_with_optimized_layout(
query + i * query_aligned_dim, recall_at, L,
query_result_ids[test_id].data() + i * recall_at);
@ -212,10 +227,9 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
return best_recall >= fail_if_recall_below ? 0 : -1;
}
int main(int argc, char** argv) {
std::string data_type, dist_fn, index_path_prefix, result_path, query_file,
gt_file;
gt_file, filter_label, label_type;
unsigned num_threads, K;
std::vector<unsigned> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
@ -238,6 +252,15 @@ int main(int argc, char** argv) {
desc.add_options()("query_file",
po::value<std::string>(&query_file)->required(),
"Query file in binary format");
desc.add_options()(
"filter_label",
po::value<std::string>(&filter_label)->default_value(std::string("")),
"Filter Label for Filtered Search");
desc.add_options()(
"label_type",
po::value<std::string>(&label_type)->default_value("uint"),
"Storage type of Labels <uint/ushort>, default value is uint which "
"will consume memory 4 bytes per filter");
desc.add_options()(
"gt_file",
po::value<std::string>(&gt_file)->default_value(std::string("null")),
@ -313,26 +336,46 @@ int main(int argc, char** argv) {
}
try {
if (data_type == std::string("int8")) {
return search_memory_index<int8_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, fail_if_recall_below);
}
else if (data_type == std::string("uint8")) {
return search_memory_index<uint8_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, fail_if_recall_below);
} else if (data_type == std::string("float")) {
return search_memory_index<float>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, fail_if_recall_below);
if (filter_label != "" && label_type == "ushort") {
if (data_type == std::string("int8")) {
return search_memory_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else if (data_type == std::string("uint8")) {
return search_memory_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else if (data_type == std::string("float")) {
return search_memory_index<float, uint16_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else {
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
} else {
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
if (data_type == std::string("int8")) {
return search_memory_index<int8_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else if (data_type == std::string("uint8")) {
return search_memory_index<uint8_t>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else if (data_type == std::string("float")) {
return search_memory_index<float>(
metric, index_path_prefix, result_path, query_file, gt_file,
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
show_qps_per_thread, filter_label, fail_if_recall_below);
} else {
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
return -1;
}
}
} catch (std::exception& e) {
std::cout << std::string(e.what()) << std::endl;

Просмотреть файл

@ -160,6 +160,8 @@ void build_incremental_index(
params.Set<bool>("saturate_graph", saturate_graph);
params.Set<unsigned>("num_rnds", 1);
params.Set<unsigned>("num_threads", thread_count);
params.Set<int>("Lf", 0); // TODO: get this from params and default to some
// value to make it backward compatible.
size_t dim, aligned_dim;
size_t num_points;

Просмотреть файл

@ -83,9 +83,10 @@ std::string get_save_filename(const std::string& save_path,
return final_path;
}
template<typename T, typename TagT>
void insert_next_batch(diskann::Index<T, TagT>& index, size_t start, size_t end,
size_t insert_threads, T* data, size_t aligned_dim) {
template<typename T, typename TagT, typename LabelT>
void insert_next_batch(diskann::Index<T, TagT, LabelT>& index, size_t start,
size_t end, size_t insert_threads, T* data,
size_t aligned_dim) {
try {
diskann::Timer insert_timer;
std::cout << std::endl
@ -116,8 +117,8 @@ void insert_next_batch(diskann::Index<T, TagT>& index, size_t start, size_t end,
}
}
template<typename T, typename TagT>
void delete_and_consolidate(diskann::Index<T, TagT>& index,
template<typename T, typename TagT, typename LabelT>
void delete_and_consolidate(diskann::Index<T, TagT, LabelT>& index,
diskann::Parameters& delete_params, size_t start,
size_t end) {
try {
@ -189,6 +190,7 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
params.Set<bool>("saturate_graph", saturate_graph);
params.Set<unsigned>("num_rnds", 1);
params.Set<unsigned>("num_threads", insert_threads);
params.Set<unsigned>("Lf", 0);
diskann::Parameters delete_params;
delete_params.Set<unsigned>("L", L);
delete_params.Set<unsigned>("R", R);
@ -222,6 +224,7 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
__FUNCSIG__, __FILE__, __LINE__);
using TagT = uint32_t;
using LabelT = uint32_t;
unsigned num_frozen = 1;
const bool enable_tags = true;
@ -232,9 +235,9 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
std::cout << "Overriding num_frozen to" << num_frozen << std::endl;
}
diskann::Index<T, TagT> index(diskann::L2, dim,
active_window + 4 * consolidate_interval, true,
params, params, enable_tags, true);
diskann::Index<T, TagT, LabelT> index(
diskann::L2, dim, active_window + 4 * consolidate_interval, true, params,
params, enable_tags, true);
index.set_start_point_at_random(static_cast<T>(start_point_norm));
index.enable_delete();

Просмотреть файл

@ -66,3 +66,8 @@ target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK
add_executable(create_disk_layout create_disk_layout.cpp)
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)
add_executable(stats_label_data stats_label_data.cpp)
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)

Просмотреть файл

@ -302,6 +302,66 @@ inline void load_bin_as_float(const char *filename, float *&data,
std::cout << "Finished converting part data to float." << std::endl;
}
// Loads one part (at most PARTSIZE points) of a binary vector file and keeps
// only the points whose label list contains `filter_label` or
// `universal_label`, converted to float.
//
// Parameters:
//   filename        - .bin file: two 32-bit ints (point count, dimension)
//                     followed by the vectors stored as T.
//   data            - out: buffer obtained from aligned_malloc, sized for every
//                     point of the part; only the first npoints_filt * ndims
//                     floats are written.
//   npts, ndims     - out: number of points in this part (before filtering)
//                     and the dimension read from the header.
//   part_num        - zero-based part index; the part starts at point
//                     part_num * PARTSIZE.
//   label_file      - not used by this function.
//   filter_label    - label a point must carry to be kept.
//   universal_label - alternative label that also qualifies a point.
//   npoints_filt    - out: number of points that matched.
//   pts_to_labels   - labels of every point, indexed by global point id.
//
// Returns a vector mapping each kept point (position in `data`) back to its
// global point id.
//
// Throws diskann::ANNException if the file cannot be opened or if
// pts_to_labels has fewer entries than the points covered by this part.
template<typename T>
inline std::vector<size_t> load_filtered_bin_as_float(
    const char *filename, float *&data, size_t &npts, size_t &ndims,
    int part_num, const char *label_file, const std::string &filter_label,
    const std::string &universal_label, size_t &npoints_filt,
    std::vector<std::vector<std::string>> &pts_to_labels) {
  std::ifstream reader(filename, std::ios::binary);
  if (reader.fail()) {
    throw diskann::ANNException(std::string("Failed to open file ") + filename,
                                -1);
  }

  std::cout << "Reading bin file " << filename << " ...\n";
  int                 npts_i32, ndims_i32;
  std::vector<size_t> rev_map;
  reader.read((char *) &npts_i32, sizeof(int));
  reader.read((char *) &ndims_i32, sizeof(int));
  // Widen before multiplying so a large part index cannot overflow int.
  uint64_t start_id = (uint64_t) part_num * PARTSIZE;
  uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t) npts_i32);
  npts = end_id - start_id;
  ndims = (unsigned) ndims_i32;
  uint64_t nptsuint64_t = (uint64_t) npts;
  uint64_t ndimsuint64_t = (uint64_t) ndims;
  npoints_filt = 0;

  // Every point of the part is looked up in pts_to_labels below; refuse to
  // continue rather than index past the end when the label data is too short.
  if (end_id > pts_to_labels.size()) {
    throw diskann::ANNException(
        std::string("Label data has ") + std::to_string(pts_to_labels.size()) +
            " entries but points up to id " + std::to_string(end_id) +
            " are required for file " + filename,
        -1);
  }

  std::cout << "#pts in part = " << npts << ", #dims = " << ndims
            << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B"
            << std::endl;
  std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
  // Skip the two header ints plus all vectors of the preceding parts.
  reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t),
               std::ios::beg);
  T *data_T = new T[nptsuint64_t * ndimsuint64_t];
  reader.read((char *) data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
  std::cout << "Finished reading part of the bin file." << std::endl;
  reader.close();

  data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);

  for (uint64_t i = 0; i < nptsuint64_t; i++) {
    const std::vector<std::string> &point_labels = pts_to_labels[start_id + i];
    // One pass over the labels: a point qualifies if it carries either the
    // filter label or the universal label.
    const bool matches = std::any_of(
        point_labels.begin(), point_labels.end(),
        [&](const std::string &lbl) {
          return lbl == filter_label || lbl == universal_label;
        });
    if (!matches)
      continue;

    rev_map.push_back(start_id + i);
    // Matching points are packed contiguously at the front of `data`.
    float   *dst = data + npoints_filt * ndimsuint64_t;
    const T *src = data_T + i * ndimsuint64_t;
    for (uint64_t j = 0; j < ndimsuint64_t; j++) {
      dst[j] = (float) src[j];
    }
    npoints_filt++;
  }
  delete[] data_T;
  std::cout << "Finished converting part data to float.. identified "
            << npoints_filt << " points matching the filter." << std::endl;
  return rev_map;
}
template<typename T>
inline void save_bin(const std::string filename, T *data, size_t npts,
size_t ndims) {
@ -334,19 +394,51 @@ inline void save_groundtruth_as_one_file(const std::string filename,
<< 2 * npts * ndims * sizeof(unsigned) + 2 * sizeof(int) << "B"
<< std::endl;
// data = new T[npts_u64 * ndims_u64];
writer.write((char *) data, npts * ndims * sizeof(uint32_t));
writer.write((char *) distances, npts * ndims * sizeof(float));
writer.close();
std::cout << "Finished writing truthset" << std::endl;
}
// Reads a text label file and appends one sorted label list per line to
// `pts_to_labels`.
//
// Each line holds a comma-separated list of labels; anything after the first
// tab on a line is ignored. Carriage returns and newlines are stripped from
// every label.
//
// Parameters:
//   line_cnt      - accepted for interface compatibility; this function does
//                   not read or write it.
//   map_file      - path of the label file.
//   pts_to_labels - out: one entry is appended per line of the file.
//
// Terminates the process with exit(-1) if the file cannot be opened or if a
// line contains no label.
inline void parse_label_file_into_vec(
    size_t &line_cnt, const std::string &map_file,
    std::vector<std::vector<std::string>> &pts_to_labels) {
  (void) line_cnt;

  std::ifstream infile(map_file);
  // Without this check an unreadable file would look like an empty one and
  // callers would later index past the end of pts_to_labels.
  if (!infile.is_open()) {
    std::cerr << "Failed to open label file " << map_file << std::endl;
    exit(-1);
  }

  std::string           line, token;
  std::set<std::string> labels;  // distinct labels, only used for reporting
  while (std::getline(infile, line)) {
    std::istringstream       iss(line);
    std::vector<std::string> lbls;
    // Only the part before the first tab carries the label list.
    getline(iss, token, '\t');
    std::istringstream new_iss(token);
    while (getline(new_iss, token, ',')) {
      token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
      token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
      lbls.push_back(token);
      labels.insert(token);
    }
    if (lbls.empty()) {
      std::cout << "No label found";
      exit(-1);
    }
    std::sort(lbls.begin(), lbls.end());
    pts_to_labels.push_back(lbls);
  }
  std::cout << "Identified " << labels.size()
            << " distinct label(s), and populated labels for "
            << pts_to_labels.size() << " points" << std::endl;
}
template<typename T>
int aux_main(const std::string &base_file, const std::string &query_file,
const std::string &gt_file, size_t k,
const diskann::Metric &metric,
const std::string &tags_file = std::string("")) {
size_t npoints, nqueries, dim;
int aux_main(const std::string &base_file, const std::string &label_file,
const std::string &query_file, const std::string &gt_file,
size_t k, const std::string &filter_label,
const std::string &universal_label, const diskann::Metric &metric,
const std::string &tags_file = std::string("")) {
size_t npoints, nqueries, dim, npoints_filt;
float *base_data;
float *query_data;
@ -392,25 +484,50 @@ int aux_main(const std::string &base_file, const std::string &query_file,
int *closest_points = new int[nqueries * k];
float *dist_closest_points = new float[nqueries * k];
std::vector<std::vector<std::string>> pts_to_labels;
if (filter_label != "")
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
std::vector<size_t> rev_map;
for (int p = 0; p < num_parts; p++) {
size_t start_id = p * PARTSIZE;
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
if (filter_label == "") {
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
} else {
rev_map = load_filtered_bin_as_float<T>(
base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
filter_label, universal_label, npoints_filt, pts_to_labels);
}
int *closest_points_part = new int[nqueries * k];
float *dist_closest_points_part = new float[nqueries * k];
auto nr = std::min(npoints, k);
exact_knn(dim, nr, closest_points_part, dist_closest_points_part, npoints,
base_data, nqueries, query_data, metric);
_u32 part_k;
if (filter_label == "") {
part_k = k < npoints ? k : npoints;
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part,
npoints, base_data, nqueries, query_data, metric);
} else {
part_k = k < npoints_filt ? k : npoints_filt;
if (npoints_filt > 0) {
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part,
npoints_filt, base_data, nqueries, query_data, metric);
}
}
for (_u64 i = 0; i < nqueries; i++) {
for (_u64 j = 0; j < nr; j++) {
for (_u64 j = 0; j < part_k; j++) {
if (tags_enabled)
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
continue;
results[i].push_back(std::make_pair(
(uint32_t) (closest_points_part[i * nr + j] + start_id),
dist_closest_points_part[i * nr + j]));
if (filter_label == "") {
results[i].push_back(std::make_pair(
(uint32_t) (closest_points_part[i * part_k + j] + start_id),
dist_closest_points_part[i * part_k + j]));
} else {
results[i].push_back(std::make_pair(
(uint32_t) (rev_map[closest_points_part[i * part_k + j]]),
dist_closest_points_part[i * part_k + j]));
}
}
}
@ -455,8 +572,9 @@ int aux_main(const std::string &base_file, const std::string &query_file,
}
int main(int argc, char **argv) {
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
uint64_t K;
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file,
label_file, filter_label, universal_label;
uint64_t K;
try {
po::options_description desc{"Arguments"};
@ -474,6 +592,17 @@ int main(int argc, char **argv) {
desc.add_options()("query_file",
po::value<std::string>(&query_file)->required(),
"File containing the query vectors in binary format");
desc.add_options()("label_file",
po::value<std::string>(&label_file)->default_value(""),
"Input labels file in txt format if present");
desc.add_options()("filter_label",
po::value<std::string>(&filter_label)->default_value(""),
"Input filter label if doing filtered groundtruth");
desc.add_options()(
"universal_label",
po::value<std::string>(&universal_label)->default_value(""),
"Universal label, if using it, only in conjunction with label_file");
desc.add_options()(
"gt_file", po::value<std::string>(&gt_file)->required(),
"File name for the writing ground truth in binary format");
@ -518,11 +647,14 @@ int main(int argc, char **argv) {
try {
if (data_type == std::string("float"))
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
aux_main<float>(base_file, label_file, query_file, gt_file, K,
filter_label, universal_label, metric, tags_file);
if (data_type == std::string("int8"))
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K,
filter_label, universal_label, metric, tags_file);
if (data_type == std::string("uint8"))
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K,
filter_label, universal_label, metric, tags_file);
} catch (const std::exception &e) {
std::cout << std::string(e.what()) << std::endl;
diskann::cerr << "Compute GT failed." << std::endl;

Просмотреть файл

@ -26,8 +26,9 @@ namespace po = boost::program_options;
template<typename T>
void bfs_count(const std::string& index_path, unsigned data_dims) {
using TagT = uint32_t;
diskann::Index<T, TagT> index(diskann::Metric::L2, data_dims, 0, false,
false);
using LabelT = uint32_t;
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0,
false, false);
std::cout << "Index class instantiated" << std::endl;
index.load(index_path.c_str(), 1, 100);
std::cout << "Index loaded" << std::endl;

Просмотреть файл

@ -0,0 +1,169 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <random>
#include <boost/program_options.hpp>
#include <math.h>
#include <cmath>
#include "utils.h"
namespace po = boost::program_options;
// Writes synthetic label assignments whose per-label frequency falls off with
// the label's rank: label i is attached to a point with probability
// distribution_factor / i, and each label has a usage quota after which it is
// no longer handed out.
class ZipfDistribution {
 public:
  ZipfDistribution(int num_points, int num_labels)
      : num_labels(num_labels), num_points(num_points),
        uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)) {
  }

  // Returns the usage quota of every label 1..num_labels: the quota of label
  // 1 is ceil(num_points * distribution_factor) and label i gets that value
  // divided by i, rounded up.
  std::unordered_map<int, int> createDistributionMap() {
    std::unordered_map<int, int> map;
    const int                    primary_label_freq =
        (int) std::ceil(num_points * distribution_factor);
    for (int i{1}; i < num_labels + 1; i++) {
      // Divide in floating point: integer division would truncate before
      // ceil() runs and could yield a quota of 0, which the "== 0" exhaustion
      // check in writeDistribution can never reach by decrementing.
      map[i] = (int) std::ceil(primary_label_freq / (double) i);
    }
    return map;
  }

  // Writes one line per point to `outfile`: the comma-separated labels chosen
  // for that point, or 0 when no label was chosen. Lines are separated by
  // '\n' with no newline after the last one. Always returns 0.
  int writeDistribution(std::ofstream& outfile) {
    auto remaining = createDistributionMap();
    for (int i{0}; i < num_points; i++) {
      bool label_written = false;
      for (auto it = remaining.begin(); it != remaining.end();) {
        std::bernoulli_distribution pick_label(distribution_factor /
                                               (double) it->first);
        if (pick_label(rand_engine)) {
          if (label_written) {
            outfile << ',';
          }
          outfile << it->first;
          label_written = true;

          // Drop the label once its quota is used up. erase() returns the
          // iterator to the next element, so iteration continues safely.
          if (--(it->second) == 0) {
            it = remaining.erase(it);
            continue;
          }
        }
        ++it;
      }
      if (!label_written) {
        outfile << 0;
      }
      if (i < num_points - 1) {
        outfile << '\n';
      }
    }
    return 0;
  }

  // Opens `filename` and writes the distribution to it.
  // Returns 0 on success, -1 if the file cannot be opened.
  int writeDistribution(std::string filename) {
    std::ofstream outfile(filename);
    if (!outfile.is_open()) {
      std::cerr << "Error: could not open output file " << filename << '\n';
      return -1;
    }
    // Propagate the result; previously control fell off the end of a
    // non-void function on the success path.
    const int status = writeDistribution(outfile);
    outfile.close();
    return status;
  }

 private:
  int          num_labels;
  const int    num_points;
  const double distribution_factor = 0.7;
  std::knuth_b rand_engine;
  const std::uniform_real_distribution<double> uniform_zero_to_one;
};
int main(int argc, char** argv) {
std::string output_file, distribution_type;
_u64 num_labels, num_points;
try {
po::options_description desc{"Arguments"};
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("output_file,O",
po::value<std::string>(&output_file)->required(),
"Filename for saving the label file");
desc.add_options()("num_points,N",
po::value<uint64_t>(&num_points)->required(),
"Number of points in dataset");
desc.add_options()("num_labels,L",
po::value<uint64_t>(&num_labels)->required(),
"Number of unique labels, up to 5000");
desc.add_options()(
"distribution_type,DT",
po::value<std::string>(&distribution_type)->default_value("random"),
"Distribution function for labels <random/zipf> defaults to random");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
} catch (const std::exception& ex) {
std::cerr << ex.what() << '\n';
return -1;
}
if (num_labels > 5000) {
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
return -1;
}
if (num_points <= 0) {
std::cerr << "Error: num_points must be greater than 0" << '\n';
return -1;
}
std::cout << "Generating synthetic labels for " << num_points
<< " points with " << num_labels << " unique labels" << '\n';
try {
std::ofstream outfile(output_file);
if (!outfile.is_open()) {
std::cerr << "Error: could not open output file " << output_file << '\n';
return -1;
}
if (distribution_type == "zipf") {
ZipfDistribution zipf(num_points, num_labels);
zipf.writeDistribution(outfile);
} else if (distribution_type == "random") {
for (int i = 0; i < num_points; i++) {
bool label_written = false;
for (int j = 1; j <= num_labels; j++) {
// 50% chance to assign each label
if (rand() > (RAND_MAX / 2)) {
if (label_written) {
outfile << ',';
}
outfile << j;
label_written = true;
}
}
if (!label_written) {
outfile << 0;
}
if (i < num_points - 1) {
outfile << '\n';
}
}
}
if (outfile.is_open()) {
outfile.close();
}
std::cout << "Labels written to " << output_file << '\n';
} catch (const std::exception& ex) {
std::cerr << "Label generation failed: " << ex.what() << '\n';
return -1;
}
return 0;
}

Просмотреть файл

@ -0,0 +1,150 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include <iostream>
#include <fstream>
#include <string>
#include <sstream>
#include <cstdint>
#include <vector>
#include <unordered_map>
#include <omp.h>
#include <string.h>
#include <atomic>
#include <cstring>
#include <iomanip>
#include <set>
#include <boost/program_options.hpp>
#include "utils.h"
#ifndef _WINDOWS
#include <sys/mman.h>
#include <unistd.h>
#include <sys/stat.h>
#include <time.h>
#else
#include <Windows.h>
#endif
namespace po = boost::program_options;
void stats_analysis(const std::string labels_file, std::string univeral_label,
_u32 density = 10) {
std::string token, line;
std::ifstream labels_stream(labels_file);
std::unordered_map<std::string, _u32> label_counts;
std::string label_with_max_points;
_u32 max_points = 0;
long long sum = 0;
long long point_cnt = 0;
float avg_labels_per_pt, avg_labels_per_pt_incl_0, mean_label_size,
mean_label_size_incl_0;
std::vector<_u32> labels_per_point;
_u32 dense_pts = 0;
if (labels_stream.is_open()) {
while (getline(labels_stream, line)) {
point_cnt++;
std::stringstream iss(line);
_u32 lbl_cnt = 0;
while (getline(iss, token, ',')) {
lbl_cnt++;
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
if (label_counts.find(token) == label_counts.end())
label_counts[token] = 0;
label_counts[token]++;
}
if (lbl_cnt >= density) {
dense_pts++;
}
labels_per_point.emplace_back(lbl_cnt);
}
}
std::cout << "fraction of dense points with >= " << density << " labels = "
<< (float) dense_pts / (float) labels_per_point.size() << std::endl;
std::sort(labels_per_point.begin(), labels_per_point.end());
std::vector<std::pair<std::string, _u32>> label_count_vec;
for (auto it = label_counts.begin(); it != label_counts.end(); it++) {
auto& lbl = *it;
label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second));
if (lbl.second > max_points) {
max_points = lbl.second;
label_with_max_points = lbl.first;
}
sum += lbl.second;
}
sort(label_count_vec.begin(), label_count_vec.end(),
[](const std::pair<std::string, _u32>& lhs,
const std::pair<std::string, _u32>& rhs) {
return lhs.second < rhs.second;
});
for (float p = 0; p < 1; p += 0.05) {
std::cout << "Percentile " << (100 * p) << "\t"
<< label_count_vec[(_u32) (p * label_count_vec.size())].first
<< " with count="
<< label_count_vec[(_u32) (p * label_count_vec.size())].second
<< std::endl;
}
std::cout << "Most common label "
<< "\t" << label_count_vec[label_count_vec.size() - 1].first
<< " with count="
<< label_count_vec[label_count_vec.size() - 1].second << std::endl;
if (label_count_vec.size() > 1)
std::cout << "Second common label "
<< "\t" << label_count_vec[label_count_vec.size() - 2].first
<< " with count="
<< label_count_vec[label_count_vec.size() - 2].second
<< std::endl;
if (label_count_vec.size() > 2)
std::cout << "Third common label "
<< "\t" << label_count_vec[label_count_vec.size() - 3].first
<< " with count="
<< label_count_vec[label_count_vec.size() - 3].second
<< std::endl;
avg_labels_per_pt = (sum) / (float) point_cnt;
mean_label_size = (sum) / label_counts.size();
std::cout << "Total number of points = " << point_cnt
<< ", number of labels = " << label_counts.size() << std::endl;
std::cout << "Average number of labels per point = " << avg_labels_per_pt
<< std::endl;
std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl;
std::cout << "Most popular label is " << label_with_max_points << " with "
<< max_points << " pts" << std::endl;
}
int main(int argc, char** argv) {
std::string labels_file, universal_label;
_u32 density;
po::options_description desc{"Arguments"};
try {
desc.add_options()("help,h", "Print information on arguments");
desc.add_options()("labels_file",
po::value<std::string>(&labels_file)->required(),
"path to labels data file.");
desc.add_options()("universal_label",
po::value<std::string>(&universal_label)->required(),
"Universal label used in labels file.");
desc.add_options()(
"density", po::value<_u32>(&density)->default_value(1),
"Number of labels each point in labels file, defaults to 1");
po::variables_map vm;
po::store(po::parse_command_line(argc, argv, desc), vm);
if (vm.count("help")) {
std::cout << desc;
return 0;
}
po::notify(vm);
} catch (const std::exception& e) {
std::cerr << e.what() << '\n';
return -1;
}
stats_analysis(labels_file, universal_label, density);
}

Просмотреть файл

@ -0,0 +1,126 @@
**Usage for filtered indices**
================================
## Building a filtered Index
DiskANN provides two algorithms for building an index with filters support: filtered-vamana and stitched-vamana. Here, we describe the parameters for building both. `tests/build_memory_index.cpp` and `tests/build_stitched_index.cpp` are respectively used to build each kind of index.
### 1. filtered-vamana
1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported.
2. **`--dist_fn`**: There are two distance functions supported: minimum Euclidean distance (l2) and maximum inner product (mips).
3. **`--data_path`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. The following `n*d*sizeof(T)` bytes contain the contents of the data, one data point at a time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices.
4. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix.
5. **`-R (--max_degree)`** (default is 64): the degree of the graph index, typically between 32 and 150. Larger R will result in larger indices and longer indexing times, but might yield better search quality.
6. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during index building. Typical values are between 75 and 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality. Note that this is to be used only for building an unfiltered index. The corresponding search list parameter for a filtered index is managed by `--FilteredLbuild`.
7. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 to 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs.
8. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth).
9. **`--build_PQ_bytes`** (default is 0): Set to a positive value less than the dimensionality of the data to enable faster index build with PQ based distance comparisons. Defaults to using full precision vectors for distance comparisons.
10. **`--use_opq`**: use the flag to use OPQ rather than PQ compression. OPQ is more space efficient for some high dimensional datasets, but also needs a bit more build time.
11. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_path`.
12. **`--universal_label`**: Optionally, the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point.
13. **`--FilteredLbuild`**: If building a filtered index, we maintain a separate search list from the one provided by `--Lbuild`.
### 2. stitched-vamana
1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported.
2. **`--data_path`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. The following `n*d*sizeof(T)` bytes contain the contents of the data, one data point at a time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices.
3. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix.
4. **`-R (--max_degree)`** (default is 64): Recall that stitched-vamana first builds a sub-index for each filter. This parameter sets the max degree for each sub-index.
5. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during sub-index building. Typical values are between 75 and 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality.
6. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 to 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs.
7. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth).
8. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_path`.
9. **`--universal_label`**: Optionally, the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point.
10. **`--stitched_R`**: Once all sub-indices are "stitched" together, we prune the resulting graph down to the degree given by this parameter.
## Computing a groundtruth file for a filtered index
In order to evaluate the performance of our algorithms, we can compare their results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `tests/utils/compute_groundtruth.cpp` to compute the latter:
1. **`--data_type`**: The type of dataset you built an index with. float(32 bit), signed int8 and unsigned uint8 are supported.
2. **`--dist_fn`**: There are two distance functions supported: l2 and mips.
3. **`--base_file`**: The input data over which to build an index, in .bin format. Corresponds to the `--data_path` argument from above.
4. **`--query_file`**: The queries to be searched on, which are stored in the same .bin format.
5. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--base_file`.
6. **`--filter_label`**: Filter for each query. For each query, a search is performed with this filter.
7. **`--universal_label`**: Corresponds to the universal label passed when building an index with filter support.
8. **`--gt_file`**: File to output results to. The binary file starts with `n`, the number of queries (4 bytes), followed by `d`, the number of ground truth elements per query (4 bytes), followed by `n*d` entries representing the `d` closest IDs for each query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes.
9. **`--K`**: The number of nearest neighbors to compute for each query.
## Searching a Filtered Index
Searching a filtered index uses the `tests/search_memory_index.cpp`:
1. **`--data_type`**: The type of dataset you built the index on. float(32 bit), signed int8 and unsigned uint8 are supported. Use the same data type as in arg (1) above used in building the index.
2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. There is an additional *fast_l2* implementation that could provide faster results for small (about a million-sized) indices. Use the same distance as in arg (2) above used in building the index. Note that stitched-vamana only supports l2.
3. **`--index_path_prefix`**: The path prefix of the index built above, i.e. the value passed as `--index_path_prefix` when building the index.
4. **`--result_path`**: search results will be stored in files, one per L value (see last arg), with specified prefix, in binary format.
5. **`-T (--num_threads)`**: The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but may lead to higher per-query latency, especially if the DRAM bandwidth is a bottleneck. So find the balance depending on throughput and latency required for your application.
6. **`--query_file`**: The queries to be searched on, in the same binary file format as the data file passed via `--data_path` above. The query file must be the same type as in argument (1).
7. **`--filter_label`**: The filter to be used when searching an index with filters. For each query, a search is performed with this filter.
8. **`--gt_file`**: The ground truth file for the queries and data file used in index construction. Use "null" if you do not have this file and if you do not want to compute recall. Note that if building a filtered index, a special groundtruth must be computed, as described above.
9. **`-K`**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors.
10. **`-L (--search_list)`**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be at least the value of *K* in (9).
Example with SIFT10K:
--------------------
We demonstrate how to work through this pipeline using the SIFT10K dataset (http://corpus-texmex.irisa.fr/). Before starting, make sure you have compiled diskANN according to the instructions in the README and can see the following binaries (paths with respect to repository root):
- `build/tests/utils/compute_groundtruth`
- `build/tests/utils/fvecs_to_bin`
- `build/tests/build_memory_index`
- `build/tests/build_stitched_index`
- `build/tests/search_memory_index`
Now, download the base and query set and convert the data to binary format:
```bash
wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz
tar -zxvf siftsmall.tar.gz
build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin
build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin
```
We now need to make a label file for our vectors. For convenience, we've included a synthetic label generator with which we can generate a label file as follows:
```bash
build/tests/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf
```
Note : `distribution_type` can be `rand` or `zipf`
This will generate a label file with 10000 data points with 50 distinct labels, ranging from 1 to 50, assigned using a zipf distribution (0 is the universal label).
The label count for each unique label in the generated label file can be printed with the following command:
```bash
build/tests/utils/stats_label_data --labels_file ./rand_labels_50_10K.txt --universal_label 0
```
Note that neither approach is designed for use with random synthetic labels, which will lead to unpredictable accuracy at search time.
Now build and search the index and measure the recall using ground truth computed using bruteforce. We search for results with the filter 35.
```bash
build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --K 100 --label_file ./rand_labels_50_10K.txt --filter_label 35 --universal_label 0
build/tests/build_memory_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index -R 32 --FilteredLbuild 50 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0
build/tests/build_stitched_index --data_type float --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index -R 20 -L 40 --stitched_R 32 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0
build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/filtered_search_results
build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/stitched_search_results
```
The output of both searches is listed below. It shows the throughput (Queries/sec) as well as the mean and 99.9 latency in microseconds for each `L` parameter provided. (Measured on a physical machine with an Intel(R) Xeon(R) W-2145 CPU and 64 GB RAM)
```
Stitched Index
Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10
=================================================================================
10 31324.39 37.33 116.79 311.90 17.80
20 91357.57 44.36 193.06 1042.30 17.90
30 69314.48 49.89 258.09 1398.00 18.20
40 61421.29 60.52 289.08 1515.00 18.60
50 54203.48 70.27 294.26 685.10 19.40
100 52904.45 79.00 336.26 1018.80 19.50
Filtered Index
Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10
=================================================================================
10 69671.84 21.48 45.25 146.20 11.60
20 168577.20 38.94 100.54 547.90 18.20
30 127129.41 52.95 126.83 768.40 19.70
40 106349.04 62.38 167.23 899.10 20.90
50 89952.33 70.95 189.12 1070.80 22.10
100 56899.00 112.26 304.67 636.60 23.80
```