Rebasing main's latest commits onto ravi/filter_support_rebased (#225)
- add code for two variants of filtered index, readme and CI tests - add utils for synthetic label generation and CI tests. * Add co-authors Co-authored-by: ravishankar <rakri@microsoft.com> Co-authored-by: Varun Sivashankar <t-varunsi@microsoft.com> --------- Co-authored-by: ravishankar <rakri@microsoft.com> Co-authored-by: David Kaczynski <dkaczynski@microsoft.com> Co-authored-by: Siddharth Gollapudi <t-gollapudis@microsoft.com> Co-authored-by: Neelam Mahapatro <nmahapatro@microsoft.com> Co-authored-by: Harsha Vardhan Simhadri <harshasi@microsoft.com> Co-authored-by: Harsha Vardhan Simhadri <harsha-simhadri@users.noreply.github.com> Co-authored-by: REDMOND\patelyash <patelyash@microsoft.com> Co-authored-by: Varun Sivashankar <t-varunsi@microsoft.com>
This commit is contained in:
Родитель
5ba6a5d2c2
Коммит
5ec769aa85
|
@ -6,6 +6,7 @@ jobs:
|
|||
runs-on: ${{ matrix.os }}
|
||||
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-2019, windows-latest]
|
||||
|
||||
|
@ -71,6 +72,9 @@ jobs:
|
|||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
- name: Searching with fast_l2 distance function
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with MIPS metric
|
||||
run: |
|
||||
|
@ -117,7 +121,6 @@ jobs:
|
|||
|
||||
|
||||
- name: Generate 10K random int8 index vectors, 1K query vectors, in 10 dims and compute GT
|
||||
run: |
|
||||
run: |
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
||||
|
@ -154,10 +157,11 @@ jobs:
|
|||
${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search an incremental index
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200;
|
||||
${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_1K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
||||
- name: test a streaming index
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
|
@ -165,6 +169,7 @@ jobs:
|
|||
|
||||
|
||||
- name: Generate 10K random uint8 index vectors, 1K query vectors, in 10 dims and compute GT
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
||||
|
@ -172,44 +177,77 @@ jobs:
|
|||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn mips --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100
|
||||
- name: build and search in-memory index with L2 metrics
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with cosine metric
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn cosine --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_cosine_rand_uint8_10D_10K_norm50.0
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
- name: build and search in-memory index with L2 metrics with PQ base distance comparisons
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1
|
||||
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5
|
||||
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (sharded graph build, L2, no diskPQ)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006
|
||||
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search disk index (one shot graph build, L2, diskPQ)
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5
|
||||
${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16
|
||||
- name: build and search an incremental index
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200;
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_10K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1
|
||||
- name: test a streaming index
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1
|
||||
|
||||
- name: Generate 10K random uint8 index vectors, 1K query vectors, 10K Label Points (50 unique labels), in 10 dims and compute GT
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0
|
||||
${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0
|
||||
${{ env.diskann_built_utils }}/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_10_10K.txt
|
||||
${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100
|
||||
- name: build and search in-memory index with labels using L2 metrics
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
|
||||
- name: build and search in-memory index with pq_dist of 5 with 10 dimensions
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5
|
||||
${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32
|
||||
|
||||
- name: Build and search stitched vamana
|
||||
if: success() || failure()
|
||||
run: |
|
||||
${{ env.diskann_built_tests }}/build_stitched_index --num_threads 48 --data_type uint8 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix ./stit_32_100_64_new --universal_label 0
|
||||
${{ env.diskann_built_tests }}/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix ./stit_32_100_64_new --query_file ./rand_uint8_10D_1K_norm50.0.bin --result_path ./rand_stit_96_10_90_new --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 10 10 10 10 10 30 50 70 90 110 130 150 170 190 210 230 250 270 290 310 330 350 370 390 410
|
||||
- uses: actions/setup-python@v3
|
||||
- name: Install cibuildwheel
|
||||
run: python -m pip install cibuildwheel==2.11.3
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "x64-Release",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "Release",
|
||||
"inheritEnvironments": [ "msvc_x64" ],
|
||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "",
|
||||
"ctestCommandArgs": ""
|
||||
},
|
||||
{
|
||||
"name": "WSL-GCC-Release",
|
||||
"generator": "Ninja",
|
||||
"configurationType": "RelWithDebInfo",
|
||||
"buildRoot": "${projectDir}\\out\\build\\${name}",
|
||||
"installRoot": "${projectDir}\\out\\install\\${name}",
|
||||
"cmakeExecutable": "cmake",
|
||||
"cmakeCommandArgs": "",
|
||||
"buildCommandArgs": "",
|
||||
"ctestCommandArgs": "",
|
||||
"inheritEnvironments": [ "linux_x64" ],
|
||||
"wslPath": "${defaultWSLPath}"
|
||||
}
|
||||
]
|
||||
}
|
13
README.md
13
README.md
|
@ -87,4 +87,17 @@ Please see the following pages on using the compiled code:
|
|||
- [Commandline interface for building and search SSD based indices](workflows/SSD_index.md)
|
||||
- [Commandline interface for building and search in memory indices](workflows/in_memory_index.md)
|
||||
- [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md)
|
||||
- [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md)
|
||||
- To be added: Python interfaces and docker files
|
||||
|
||||
Please cite this software in your work as:
|
||||
|
||||
```
|
||||
@misc{diskann-github,
|
||||
author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan}},
|
||||
title = {{DiskANN: Scalable, efficient and Feature-rich ANNS}},
|
||||
url = {https://github.com/Microsoft/DiskANN},
|
||||
version = {0.5},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -40,7 +40,7 @@ namespace diskann {
|
|||
const uint32_t WARMUP_L = 20;
|
||||
const uint32_t NUM_KMEANS_REPS = 12;
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT>
|
||||
class PQFlashIndex;
|
||||
|
||||
DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str);
|
||||
|
@ -68,38 +68,47 @@ namespace diskann {
|
|||
uint64_t warmup_aligned_dim);
|
||||
#endif
|
||||
|
||||
DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix,
|
||||
const std::string &vamana_suffix,
|
||||
const std::string &idmaps_prefix,
|
||||
const std::string &idmaps_suffix,
|
||||
const _u64 nshards, unsigned max_degree,
|
||||
const std::string &output_vamana,
|
||||
const std::string &medoids_file);
|
||||
DISKANN_DLLEXPORT int merge_shards(
|
||||
const std::string &vamana_prefix, const std::string &vamana_suffix,
|
||||
const std::string &idmaps_prefix, const std::string &idmaps_suffix,
|
||||
const _u64 nshards, unsigned max_degree, const std::string &output_vamana,
|
||||
const std::string &medoids_file, bool use_filters = false,
|
||||
const std::string &labels_to_medoids_file = std::string(""));
|
||||
|
||||
DISKANN_DLLEXPORT void extract_shard_labels(
|
||||
const std::string &in_label_file, const std::string &shard_ids_bin,
|
||||
const std::string &shard_label_file);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT std::string preprocess_base_file(
|
||||
const std::string &infile, const std::string &indexPrefix,
|
||||
diskann::Metric &distMetric);
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
DISKANN_DLLEXPORT int build_merged_vamana_index(
|
||||
std::string base_file, diskann::Metric _compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_file,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters = false, const std::string &label_file = std::string(""),
|
||||
const std::string &labels_to_medoids_file = std::string(""),
|
||||
const std::string &universal_label = "", const _u32 Lf = 0);
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT>
|
||||
DISKANN_DLLEXPORT uint32_t optimize_beamwidth(
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> &_pFlashIndex, T *tuning_sample,
|
||||
_u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
|
||||
uint32_t nthreads, uint32_t start_bw = 2);
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &_pFlashIndex,
|
||||
T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim,
|
||||
uint32_t L, uint32_t nthreads, uint32_t start_bw = 2);
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath,
|
||||
const char *indexFilePath,
|
||||
const char *indexBuildParameters,
|
||||
diskann::Metric _compareMetric,
|
||||
bool use_opq = false);
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
DISKANN_DLLEXPORT int build_disk_index(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric _compareMetric,
|
||||
bool use_opq = false, bool use_filters = false,
|
||||
const std::string &label_file =
|
||||
std::string(""), // default is empty string for no label_file
|
||||
const std::string &universal_label = "", const _u32 filter_threshold = 0,
|
||||
const _u32 Lf = 0); // default is empty string for no universal label
|
||||
|
||||
template<typename T>
|
||||
DISKANN_DLLEXPORT void create_disk_layout(
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#define DEFAULT_MAXC 750
|
||||
|
||||
namespace diskann {
|
||||
|
||||
inline double estimate_ram_usage(_u64 size, _u32 dim, _u32 datasize,
|
||||
_u32 degree) {
|
||||
double size_of_data = ((double) size) * ROUND_UP(dim, 8) * datasize;
|
||||
|
@ -60,7 +61,7 @@ namespace diskann {
|
|||
}
|
||||
};
|
||||
|
||||
template<typename T, typename TagT = uint32_t>
|
||||
template<typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
class Index {
|
||||
/**************************************************************************
|
||||
*
|
||||
|
@ -129,6 +130,17 @@ namespace diskann {
|
|||
Parameters ¶meters,
|
||||
const std::vector<TagT> &tags);
|
||||
|
||||
// Filtered Support
|
||||
DISKANN_DLLEXPORT void build_filtered_index(
|
||||
const char *filename, const std::string &label_file,
|
||||
const size_t num_points_to_load, Parameters ¶meters,
|
||||
const std::vector<TagT> &tags = std::vector<TagT>());
|
||||
|
||||
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
|
||||
|
||||
// Get converted integer label from string to int map (_label_map)
|
||||
DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label);
|
||||
|
||||
// Set starting point of an index before inserting any points incrementally
|
||||
DISKANN_DLLEXPORT void set_start_point(T *data);
|
||||
// Set starting point to a random point on a sphere of certain radius
|
||||
|
@ -155,6 +167,12 @@ namespace diskann {
|
|||
float *distances,
|
||||
std::vector<T *> &res_vectors);
|
||||
|
||||
// Filter support search
|
||||
template<typename IndexType>
|
||||
DISKANN_DLLEXPORT std::pair<uint32_t, uint32_t> search_with_filters(
|
||||
const T *query, const LabelT &filter_label, const size_t K,
|
||||
const unsigned L, IndexType *indices, float *distances);
|
||||
|
||||
// Will fail if tag already in the index or if tag=0.
|
||||
DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag);
|
||||
|
||||
|
@ -177,6 +195,8 @@ namespace diskann {
|
|||
DISKANN_DLLEXPORT consolidation_report
|
||||
consolidate_deletes(const Parameters ¶meters);
|
||||
|
||||
DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters ¶meters);
|
||||
|
||||
DISKANN_DLLEXPORT bool is_index_saved();
|
||||
|
||||
// repositions frozen points to the end of _data - if they have been moved
|
||||
|
@ -208,8 +228,8 @@ namespace diskann {
|
|||
|
||||
protected:
|
||||
// No copy/assign.
|
||||
Index(const Index<T, TagT> &) = delete;
|
||||
Index<T, TagT> &operator=(const Index<T, TagT> &) = delete;
|
||||
Index(const Index<T, TagT, LabelT> &) = delete;
|
||||
Index<T, TagT, LabelT> &operator=(const Index<T, TagT, LabelT> &) = delete;
|
||||
|
||||
// Use after _data and _nd have been populated
|
||||
// Acquire exclusive _update_lock before calling
|
||||
|
@ -223,14 +243,23 @@ namespace diskann {
|
|||
// determines navigating node of the graph by calculating medoid of datafopt
|
||||
unsigned calculate_entry_point();
|
||||
|
||||
void parse_label_file(const std::string &label_file,
|
||||
size_t &num_pts_labels);
|
||||
|
||||
std::unordered_map<std::string, LabelT> load_label_map(
|
||||
const std::string &map_file);
|
||||
|
||||
std::pair<uint32_t, uint32_t> iterate_to_fixed_point(
|
||||
const T *node_coords, const unsigned Lindex,
|
||||
const std::vector<unsigned> &init_ids, InMemQueryScratch<T> *scratch,
|
||||
bool use_filter, const std::vector<LabelT> &filters,
|
||||
bool ret_frozen = true, bool search_invocation = false);
|
||||
|
||||
void search_for_point_and_prune(int location, _u32 Lindex,
|
||||
std::vector<unsigned> &pruned_list,
|
||||
InMemQueryScratch<T> *scratch);
|
||||
InMemQueryScratch<T> *scratch,
|
||||
bool use_filter = false,
|
||||
_u32 filteredLindex = 0);
|
||||
|
||||
void prune_neighbors(const unsigned location, std::vector<Neighbor> &pool,
|
||||
std::vector<unsigned> &pruned_list,
|
||||
|
@ -342,6 +371,19 @@ namespace diskann {
|
|||
bool _enable_tags = false;
|
||||
bool _normalize_vecs = false; // Using normalied L2 for cosine.
|
||||
|
||||
// Filter Support
|
||||
|
||||
bool _filtered_index = false;
|
||||
std::vector<std::vector<LabelT>> _pts_to_labels;
|
||||
tsl::robin_set<LabelT> _labels;
|
||||
std::string _labels_file;
|
||||
std::unordered_map<LabelT, _u32> _label_to_medoid_id;
|
||||
std::unordered_map<_u32, _u32> _medoid_counts;
|
||||
bool _use_universal_label = false;
|
||||
LabelT _universal_label = 0;
|
||||
uint32_t _filterIndexingQueueSize;
|
||||
std::unordered_map<std::string, LabelT> _label_map;
|
||||
|
||||
// Indexing parameters
|
||||
uint32_t _indexingQueueSize;
|
||||
uint32_t _indexingRange;
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
|
||||
namespace diskann {
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
class PQFlashIndex {
|
||||
public:
|
||||
DISKANN_DLLEXPORT PQFlashIndex(
|
||||
|
@ -70,11 +70,26 @@ namespace diskann {
|
|||
float *res_dists, const _u64 beam_width,
|
||||
const bool use_reorder_data = false, QueryStats *stats = nullptr);
|
||||
|
||||
DISKANN_DLLEXPORT void cached_beam_search(
|
||||
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
|
||||
float *res_dists, const _u64 beam_width, const bool use_filter,
|
||||
const LabelT &filter_label, const bool use_reorder_data = false,
|
||||
QueryStats *stats = nullptr);
|
||||
|
||||
DISKANN_DLLEXPORT void cached_beam_search(
|
||||
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
|
||||
float *res_dists, const _u64 beam_width, const _u32 io_limit,
|
||||
const bool use_reorder_data = false, QueryStats *stats = nullptr);
|
||||
|
||||
DISKANN_DLLEXPORT void cached_beam_search(
|
||||
const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids,
|
||||
float *res_dists, const _u64 beam_width, const bool use_filter,
|
||||
const LabelT &filter_label, const _u32 io_limit,
|
||||
const bool use_reorder_data = false, QueryStats *stats = nullptr);
|
||||
|
||||
DISKANN_DLLEXPORT LabelT
|
||||
get_converted_label(const std::string &filter_label);
|
||||
|
||||
DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range,
|
||||
const _u64 min_l_search,
|
||||
const _u64 max_l_search,
|
||||
|
@ -94,12 +109,26 @@ namespace diskann {
|
|||
DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads,
|
||||
_u64 visited_reserve = 4096);
|
||||
|
||||
DISKANN_DLLEXPORT void set_universal_label(const LabelT &label);
|
||||
|
||||
private:
|
||||
DISKANN_DLLEXPORT inline bool point_has_label(_u32 point_id, _u32 label_id);
|
||||
std::unordered_map<std::string, LabelT> load_label_map(
|
||||
const std::string &map_file);
|
||||
DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file,
|
||||
size_t &num_pts_labels);
|
||||
DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file,
|
||||
_u32 &num_pts,
|
||||
_u32 &num_total_labels);
|
||||
DISKANN_DLLEXPORT inline int32_t get_filter_number(
|
||||
const LabelT &filter_label);
|
||||
|
||||
// index info
|
||||
// nhood of node `i` is in sector: [i / nnodes_per_sector]
|
||||
// offset in sector: [(i % nnodes_per_sector) * max_node_len]
|
||||
// nnbrs of node `i`: *(unsigned*) (buf)
|
||||
// nbrs of node `i`: ((unsigned*)buf) + 1
|
||||
|
||||
_u64 max_node_len = 0, nnodes_per_sector = 0, max_degree = 0;
|
||||
|
||||
// Data used for searching with re-order vectors
|
||||
|
@ -171,6 +200,20 @@ namespace diskann {
|
|||
bool reorder_data_exists = false;
|
||||
_u64 reoreder_data_offset = 0;
|
||||
|
||||
// filter support
|
||||
_u32 *_pts_to_label_offsets = nullptr;
|
||||
_u32 *_pts_to_labels = nullptr;
|
||||
tsl::robin_set<LabelT> _labels;
|
||||
std::unordered_map<LabelT, _u32> _filter_to_medoid_id;
|
||||
bool _use_universal_label;
|
||||
_u32 _universal_filter_num;
|
||||
std::vector<LabelT> _filter_list;
|
||||
tsl::robin_set<_u32> _dummy_pts;
|
||||
tsl::robin_set<_u32> _has_dummy_pts;
|
||||
tsl::robin_map<_u32, _u32> _dummy_to_real_map;
|
||||
tsl::robin_map<_u32, std::vector<_u32>> _real_to_dummy_map;
|
||||
std::unordered_map<std::string, LabelT> _label_map;
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
// Set to a larger value than the actual header to accommodate
|
||||
// any additions we make to the header. This is an outer limit
|
||||
|
|
|
@ -31,19 +31,24 @@ typedef int FileHandle;
|
|||
#include "memory_mapped_files.h"
|
||||
#endif
|
||||
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
|
||||
// taken from
|
||||
// https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h
|
||||
// round up X to the nearest multiple of Y
|
||||
#define ROUND_UP(X, Y) \
|
||||
((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y))
|
||||
((((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0)) * (Y))
|
||||
|
||||
#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0))
|
||||
#define DIV_ROUND_UP(X, Y) \
|
||||
(((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0))
|
||||
|
||||
// round down X to the nearest multiple of Y
|
||||
#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y))
|
||||
#define ROUND_DOWN(X, Y) (((uint64_t) (X) / (Y)) * (Y))
|
||||
|
||||
// alignment tests
|
||||
#define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0)
|
||||
#define IS_ALIGNED(X, Y) ((uint64_t) (X) % (uint64_t) (Y) == 0)
|
||||
#define IS_512_ALIGNED(X) IS_ALIGNED(X, 512)
|
||||
#define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096)
|
||||
#define METADATA_SIZE \
|
||||
|
@ -92,8 +97,9 @@ typedef uint16_t _u16;
|
|||
typedef int16_t _s16;
|
||||
typedef uint8_t _u8;
|
||||
typedef int8_t _s8;
|
||||
inline void open_file_to_write(std::ofstream& writer,
|
||||
const std::string& filename) {
|
||||
|
||||
inline void open_file_to_write(std::ofstream& writer,
|
||||
const std::string& filename) {
|
||||
writer.exceptions(std::ofstream::failbit | std::ofstream::badbit);
|
||||
if (!file_exists(filename))
|
||||
writer.open(filename, std::ios::binary | std::ios::out);
|
||||
|
@ -144,6 +150,48 @@ inline int delete_file(const std::string& fileName) {
|
|||
}
|
||||
}
|
||||
|
||||
inline void convert_labels_string_to_int(const std::string& inFileName,
|
||||
const std::string& outFileName,
|
||||
const std::string& mapFileName,
|
||||
const std::string& unv_label) {
|
||||
std::unordered_map<std::string, _u32> string_int_map;
|
||||
std::ofstream label_writer(outFileName);
|
||||
std::ifstream label_reader(inFileName);
|
||||
if (unv_label != "")
|
||||
string_int_map[unv_label] = 0;
|
||||
std::string line, token;
|
||||
while (std::getline(label_reader, line)) {
|
||||
std::istringstream new_iss(line);
|
||||
std::vector<_u32> lbls;
|
||||
while (getline(new_iss, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
if (string_int_map.find(token) == string_int_map.end()) {
|
||||
_u32 nextId = (_u32) string_int_map.size() + 1;
|
||||
string_int_map[token] = nextId;
|
||||
}
|
||||
lbls.push_back(string_int_map[token]);
|
||||
}
|
||||
if (lbls.size() <= 0) {
|
||||
std::cout << "No label found";
|
||||
exit(-1);
|
||||
}
|
||||
for (size_t j = 0; j < lbls.size(); j++) {
|
||||
if (j != lbls.size() - 1)
|
||||
label_writer << lbls[j] << ",";
|
||||
else
|
||||
label_writer << lbls[j] << std::endl;
|
||||
}
|
||||
}
|
||||
label_writer.close();
|
||||
|
||||
std::ofstream map_writer(mapFileName);
|
||||
for (auto mp : string_int_map) {
|
||||
map_writer << mp.first << "\t" << mp.second << std::endl;
|
||||
}
|
||||
map_writer.close();
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
class AlignedFileReader;
|
||||
#endif
|
||||
|
@ -568,6 +616,19 @@ namespace diskann {
|
|||
}
|
||||
#endif
|
||||
|
||||
inline void copy_file(std::string in_file, std::string out_file) {
|
||||
std::ifstream source(in_file, std::ios::binary);
|
||||
std::ofstream dest(out_file, std::ios::binary);
|
||||
|
||||
std::istreambuf_iterator<char> begin_source(source);
|
||||
std::istreambuf_iterator<char> end_source;
|
||||
std::ostreambuf_iterator<char> begin_dest(dest);
|
||||
std::copy(begin_source, end_source, begin_dest);
|
||||
|
||||
source.close();
|
||||
dest.close();
|
||||
}
|
||||
|
||||
DISKANN_DLLEXPORT double calculate_recall(
|
||||
unsigned num_queries, unsigned* gold_std, float* gs_dist, unsigned dim_gs,
|
||||
unsigned* our_results, unsigned dim_or, unsigned recall_at);
|
||||
|
@ -947,7 +1008,7 @@ inline void normalize(T* arr, size_t dim) {
|
|||
}
|
||||
sum = sqrt(sum);
|
||||
for (uint32_t i = 0; i < dim; i++) {
|
||||
arr[i] = (T)(arr[i] / sum);
|
||||
arr[i] = (T) (arr[i] / sum);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,13 +25,12 @@ PYBIND11_MAKE_OPAQUE(std::vector<float>);
|
|||
PYBIND11_MAKE_OPAQUE(std::vector<int8_t>);
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<uint8_t>);
|
||||
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace diskann;
|
||||
|
||||
template<class T>
|
||||
struct DiskANNIndex {
|
||||
PQFlashIndex<T> * pq_flash_index;
|
||||
PQFlashIndex<T> *pq_flash_index;
|
||||
std::shared_ptr<AlignedFileReader> reader;
|
||||
|
||||
DiskANNIndex(diskann::Metric metric) {
|
||||
|
@ -74,8 +73,7 @@ struct DiskANNIndex {
|
|||
const size_t num_nodes_to_cache, int cache_mechanism) {
|
||||
const std::string index_path =
|
||||
index_path_prefix + std::string("_disk.index");
|
||||
int load_success =
|
||||
pq_flash_index->load(num_threads, index_path.c_str());
|
||||
int load_success = pq_flash_index->load(num_threads, index_path.c_str());
|
||||
if (load_success != 0) {
|
||||
return load_success;
|
||||
}
|
||||
|
@ -197,8 +195,8 @@ struct DiskANNIndex {
|
|||
const int num_threads) {
|
||||
py::array_t<unsigned> offsets(num_queries + 1);
|
||||
|
||||
std::vector<std::vector<_u64> > u64_ids(num_queries);
|
||||
std::vector<std::vector<float> > dists(num_queries);
|
||||
std::vector<std::vector<_u64>> u64_ids(num_queries);
|
||||
std::vector<std::vector<float>> dists(num_queries);
|
||||
|
||||
auto offsets_mutable = offsets.mutable_unchecked();
|
||||
offsets_mutable(0) = 0;
|
||||
|
@ -246,8 +244,9 @@ struct DiskANNIndex {
|
|||
stats, num_queries,
|
||||
[](const diskann::QueryStats &stats) { return stats.n_cmps; });
|
||||
delete[] stats;
|
||||
return std::make_pair(std::make_pair(offsets, std::make_pair(ids, res_dists)),
|
||||
collective_stats);
|
||||
return std::make_pair(
|
||||
std::make_pair(offsets, std::make_pair(ids, res_dists)),
|
||||
collective_stats);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -259,16 +258,15 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
m.attr("__version__") = "dev";
|
||||
#endif
|
||||
|
||||
py::bind_vector<std::vector<unsigned> >(m, "VectorUnsigned");
|
||||
py::bind_vector<std::vector<float> >(m, "VectorFloat");
|
||||
py::bind_vector<std::vector<int8_t> >(m, "VectorInt8");
|
||||
py::bind_vector<std::vector<uint8_t> >(m, "VectorUInt8");
|
||||
|
||||
py::bind_vector<std::vector<unsigned>>(m, "VectorUnsigned");
|
||||
py::bind_vector<std::vector<float>>(m, "VectorFloat");
|
||||
py::bind_vector<std::vector<int8_t>>(m, "VectorInt8");
|
||||
py::bind_vector<std::vector<uint8_t>>(m, "VectorUInt8");
|
||||
|
||||
py::enum_<Metric>(m, "Metric")
|
||||
.value("L2", Metric::L2)
|
||||
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
|
||||
.export_values();
|
||||
.value("L2", Metric::L2)
|
||||
.value("INNER_PRODUCT", Metric::INNER_PRODUCT)
|
||||
.export_values();
|
||||
|
||||
py::class_<Parameters>(m, "Parameters")
|
||||
.def(py::init<>())
|
||||
|
@ -294,9 +292,11 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
py::class_<AlignedFileReader>(m, "AlignedFileReader");
|
||||
|
||||
#ifdef _WINDOWS
|
||||
py::class_<WindowsAlignedFileReader>(m, "WindowsAlignedFileReader").def(py::init<>());
|
||||
py::class_<WindowsAlignedFileReader>(m, "WindowsAlignedFileReader")
|
||||
.def(py::init<>());
|
||||
#else
|
||||
py::class_<LinuxAlignedFileReader>(m, "LinuxAlignedFileReader").def(py::init<>());
|
||||
py::class_<LinuxAlignedFileReader>(m, "LinuxAlignedFileReader")
|
||||
.def(py::init<>());
|
||||
#endif
|
||||
|
||||
m.def(
|
||||
|
@ -327,7 +327,7 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
[](const std::string &path, std::vector<unsigned> &ids,
|
||||
std::vector<float> &distances) {
|
||||
unsigned *id_ptr = nullptr;
|
||||
float * dist_ptr = nullptr;
|
||||
float *dist_ptr = nullptr;
|
||||
size_t num, dims;
|
||||
load_truthset(path, id_ptr, dist_ptr, num, dims);
|
||||
// TODO: Remove redundant copies.
|
||||
|
@ -349,7 +349,7 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
const unsigned ground_truth_dims, std::vector<unsigned> &results,
|
||||
const unsigned result_dims, const unsigned recall_at) {
|
||||
unsigned *gti_ptr = ground_truth_ids.data();
|
||||
float * gtd_ptr = ground_truth_dists.data();
|
||||
float *gtd_ptr = ground_truth_dists.data();
|
||||
unsigned *r_ptr = results.data();
|
||||
|
||||
double total_recall = 0;
|
||||
|
@ -390,10 +390,10 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
std::vector<float> &ground_truth_dists,
|
||||
const unsigned ground_truth_dims,
|
||||
py::array_t<unsigned, py::array::c_style | py::array::forcecast>
|
||||
& results,
|
||||
&results,
|
||||
const unsigned result_dims, const unsigned recall_at) {
|
||||
unsigned *gti_ptr = ground_truth_ids.data();
|
||||
float * gtd_ptr = ground_truth_dists.data();
|
||||
float *gtd_ptr = ground_truth_dists.data();
|
||||
unsigned *r_ptr = results.mutable_data();
|
||||
|
||||
double total_recall = 0;
|
||||
|
@ -434,9 +434,9 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
size_t dims) { save_bin<_u32>(file_name, data.data(), npts, dims); },
|
||||
py::arg("file_name"), py::arg("data"), py::arg("npts"), py::arg("dims"));
|
||||
|
||||
py::class_<DiskANNIndex<float> >(m, "DiskANNFloatIndex")
|
||||
py::class_<DiskANNIndex<float>>(m, "DiskANNFloatIndex")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<float> >(
|
||||
return std::unique_ptr<DiskANNIndex<float>>(
|
||||
new DiskANNIndex<float>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<float>::cache_bfs_levels,
|
||||
|
@ -462,8 +462,8 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
.def("batch_range_search_numpy_input",
|
||||
&DiskANNIndex<float>::batch_range_search_numpy_input,
|
||||
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
|
||||
py::arg("num_threads"))
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
|
||||
py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<float> &self, const char *data_file_path,
|
||||
|
@ -485,13 +485,13 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
py::arg("indexing_ram_limit"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes") = 0);
|
||||
|
||||
py::class_<DiskANNIndex<int8_t> >(m, "DiskANNInt8Index")
|
||||
py::class_<DiskANNIndex<int8_t>>(m, "DiskANNInt8Index")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<int8_t> >(
|
||||
return std::unique_ptr<DiskANNIndex<int8_t>>(
|
||||
new DiskANNIndex<int8_t>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<int8_t>::cache_bfs_levels,
|
||||
py::arg("num_nodes_to_cache"))
|
||||
py::arg("num_nodes_to_cache"))
|
||||
.def("load_index", &DiskANNIndex<int8_t>::load_index,
|
||||
py::arg("index_path_prefix"), py::arg("num_threads"),
|
||||
py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1)
|
||||
|
@ -513,8 +513,8 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
.def("batch_range_search_numpy_input",
|
||||
&DiskANNIndex<int8_t>::batch_range_search_numpy_input,
|
||||
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
|
||||
py::arg("num_threads"))
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
|
||||
py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<int8_t> &self, const char *data_file_path,
|
||||
|
@ -536,10 +536,9 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
py::arg("indexing_ram_limit"), py::arg("num_threads"),
|
||||
py::arg("pq_disk_bytes") = 0);
|
||||
|
||||
|
||||
py::class_<DiskANNIndex<uint8_t> >(m, "DiskANNUInt8Index")
|
||||
py::class_<DiskANNIndex<uint8_t>>(m, "DiskANNUInt8Index")
|
||||
.def(py::init([](diskann::Metric metric) {
|
||||
return std::unique_ptr<DiskANNIndex<uint8_t> >(
|
||||
return std::unique_ptr<DiskANNIndex<uint8_t>>(
|
||||
new DiskANNIndex<uint8_t>(metric));
|
||||
}))
|
||||
.def("cache_bfs_levels", &DiskANNIndex<uint8_t>::cache_bfs_levels,
|
||||
|
@ -565,8 +564,8 @@ PYBIND11_MODULE(diskannpy, m) {
|
|||
.def("batch_range_search_numpy_input",
|
||||
&DiskANNIndex<uint8_t>::batch_range_search_numpy_input,
|
||||
py::arg("queries"), py::arg("dim"), py::arg("num_queries"),
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"),
|
||||
py::arg("num_threads"))
|
||||
py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"),
|
||||
py::arg("beam_width"), py::arg("num_threads"))
|
||||
.def(
|
||||
"build",
|
||||
[](DiskANNIndex<uint8_t> &self, const char *data_file_path,
|
||||
|
|
|
@ -242,7 +242,8 @@ namespace diskann {
|
|||
const std::string &idmaps_prefix,
|
||||
const std::string &idmaps_suffix, const _u64 nshards,
|
||||
unsigned max_degree, const std::string &output_vamana,
|
||||
const std::string &medoids_file) {
|
||||
const std::string &medoids_file, bool use_filters,
|
||||
const std::string &labels_to_medoids_file) {
|
||||
// Read ID maps
|
||||
std::vector<std::string> vamana_names(nshards);
|
||||
std::vector<std::vector<unsigned>> idmaps(nshards);
|
||||
|
@ -283,6 +284,57 @@ namespace diskann {
|
|||
});
|
||||
diskann::cout << "Finished computing node -> shards map" << std::endl;
|
||||
|
||||
// will merge all the labels to medoids files of each shard into one
|
||||
// combined file
|
||||
if (use_filters) {
|
||||
std::unordered_map<unsigned, std::vector<_u32>> global_label_to_medoids;
|
||||
|
||||
for (_u64 i = 0; i < nshards; i++) {
|
||||
std::ifstream mapping_reader;
|
||||
std::string map_file = vamana_names[i] + "_labels_to_medoids.txt";
|
||||
mapping_reader.open(map_file);
|
||||
|
||||
std::string line, token;
|
||||
unsigned line_cnt = 0;
|
||||
|
||||
while (std::getline(mapping_reader, line)) {
|
||||
std::istringstream iss(line);
|
||||
_u32 cnt = 0;
|
||||
_u32 medoid;
|
||||
_u32 label;
|
||||
while (std::getline(iss, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'),
|
||||
token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'),
|
||||
token.end());
|
||||
|
||||
unsigned token_as_num = std::stoul(token);
|
||||
|
||||
if (cnt == 0)
|
||||
label = token_as_num;
|
||||
else
|
||||
medoid = token_as_num;
|
||||
cnt++;
|
||||
}
|
||||
global_label_to_medoids[label].push_back(idmaps[i][medoid]);
|
||||
line_cnt++;
|
||||
}
|
||||
mapping_reader.close();
|
||||
}
|
||||
|
||||
std::ofstream mapping_writer(labels_to_medoids_file);
|
||||
assert(mapping_writer.is_open());
|
||||
for (auto iter : global_label_to_medoids) {
|
||||
mapping_writer << iter.first << ", ";
|
||||
auto &vec = iter.second;
|
||||
for (_u32 idx = 0; idx < vec.size() - 1; idx++) {
|
||||
mapping_writer << vec[idx] << ", ";
|
||||
}
|
||||
mapping_writer << vec[vec.size() - 1] << std::endl;
|
||||
}
|
||||
mapping_writer.close();
|
||||
}
|
||||
|
||||
// create cached vamana readers
|
||||
std::vector<cached_ifstream> vamana_readers(nshards);
|
||||
for (_u64 i = 0; i < nshards; i++) {
|
||||
|
@ -384,10 +436,16 @@ namespace diskann {
|
|||
}
|
||||
// read from shard_id ifstream
|
||||
vamana_readers[shard_id].read((char *) &shard_nnbrs, sizeof(unsigned));
|
||||
std::vector<unsigned> shard_nhood(shard_nnbrs);
|
||||
vamana_readers[shard_id].read((char *) shard_nhood.data(),
|
||||
shard_nnbrs * sizeof(unsigned));
|
||||
|
||||
if (shard_nnbrs == 0) {
|
||||
diskann::cout << "WARNING: shard #" << shard_id << ", node_id "
|
||||
<< node_id << " has 0 nbrs" << std::endl;
|
||||
}
|
||||
|
||||
std::vector<unsigned> shard_nhood(shard_nnbrs);
|
||||
if (shard_nnbrs > 0)
|
||||
vamana_readers[shard_id].read((char *) shard_nhood.data(),
|
||||
shard_nnbrs * sizeof(unsigned));
|
||||
// rename nodes
|
||||
for (_u64 j = 0; j < shard_nnbrs; j++) {
|
||||
if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) {
|
||||
|
@ -402,8 +460,10 @@ namespace diskann {
|
|||
nnbrs = (unsigned) (std::min)(final_nhood.size(), (uint64_t) max_degree);
|
||||
// write into merged ofstream
|
||||
merged_vamana_writer.write((char *) &nnbrs, sizeof(unsigned));
|
||||
merged_vamana_writer.write((char *) final_nhood.data(),
|
||||
nnbrs * sizeof(unsigned));
|
||||
if (nnbrs > 0) {
|
||||
merged_vamana_writer.write((char *) final_nhood.data(),
|
||||
nnbrs * sizeof(unsigned));
|
||||
}
|
||||
merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned));
|
||||
for (auto &p : final_nhood)
|
||||
nhood_set[p] = 0;
|
||||
|
@ -418,47 +478,204 @@ namespace diskann {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// TODO: Make this a streaming implementation to avoid exceeding the memory
|
||||
// budget
|
||||
/* If the number of filters per point N exceeds the graph degree R,
|
||||
then it is difficult to have edges to all labels from this point.
|
||||
This function break up such dense points to have only a threshold of maximum
|
||||
T labels per point It divides one graph nodes to multiple nodes and append
|
||||
the new nodes at the end. The dummy map contains the real graph id of the
|
||||
new nodes added to the graph */
|
||||
template<typename T>
|
||||
int build_merged_vamana_index(std::string base_file,
|
||||
diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate,
|
||||
double ram_budget, std::string mem_index_path,
|
||||
std::string medoids_file,
|
||||
std::string centroids_file,
|
||||
size_t build_pq_bytes, bool use_opq) {
|
||||
void breakup_dense_points(const std::string data_file,
|
||||
const std::string labels_file, _u32 density,
|
||||
const std::string out_data_file,
|
||||
const std::string out_labels_file,
|
||||
const std::string out_metadata_file) {
|
||||
std::string token, line;
|
||||
std::ifstream labels_stream(labels_file);
|
||||
T *data;
|
||||
_u64 npts, ndims;
|
||||
diskann::load_bin<T>(data_file, data, npts, ndims);
|
||||
|
||||
std::unordered_map<_u32, _u32> dummy_pt_ids;
|
||||
_u32 next_dummy_id = (_u32) npts;
|
||||
|
||||
_u32 point_cnt = 0;
|
||||
|
||||
std::vector<std::vector<uint32_t>> labels_per_point;
|
||||
labels_per_point.resize(npts);
|
||||
|
||||
_u32 dense_pts = 0;
|
||||
if (labels_stream.is_open()) {
|
||||
while (getline(labels_stream, line)) {
|
||||
std::stringstream iss(line);
|
||||
_u32 lbl_cnt = 0;
|
||||
_u32 label_host = point_cnt;
|
||||
while (getline(iss, token, ',')) {
|
||||
if (lbl_cnt == density) {
|
||||
if (label_host == point_cnt)
|
||||
dense_pts++;
|
||||
label_host = next_dummy_id;
|
||||
labels_per_point.resize(next_dummy_id + 1);
|
||||
dummy_pt_ids[next_dummy_id] = (_u32) point_cnt;
|
||||
next_dummy_id++;
|
||||
lbl_cnt = 0;
|
||||
}
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'),
|
||||
token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'),
|
||||
token.end());
|
||||
unsigned token_as_num = std::stoul(token);
|
||||
labels_per_point[label_host].push_back(token_as_num);
|
||||
lbl_cnt++;
|
||||
}
|
||||
point_cnt++;
|
||||
}
|
||||
}
|
||||
diskann::cout << "fraction of dense points with >= " << density
|
||||
<< " labels = " << (float) dense_pts / (float) npts
|
||||
<< std::endl;
|
||||
|
||||
if (labels_per_point.size() != 0) {
|
||||
diskann::cout << labels_per_point.size() << " is the new number of points"
|
||||
<< std::endl;
|
||||
std::ofstream label_writer(out_labels_file);
|
||||
assert(label_writer.is_open());
|
||||
for (_u32 i = 0; i < labels_per_point.size(); i++) {
|
||||
for (_u32 j = 0; j < (labels_per_point[i].size() - 1); j++) {
|
||||
label_writer << labels_per_point[i][j] << ",";
|
||||
}
|
||||
if (labels_per_point[i].size() != 0)
|
||||
label_writer << labels_per_point[i][labels_per_point[i].size() - 1];
|
||||
label_writer << std::endl;
|
||||
}
|
||||
label_writer.close();
|
||||
}
|
||||
|
||||
if (dummy_pt_ids.size() != 0) {
|
||||
diskann::cout << dummy_pt_ids.size()
|
||||
<< " is the number of dummy points created" << std::endl;
|
||||
data = (T *) std::realloc((void *) data,
|
||||
labels_per_point.size() * ndims * sizeof(T));
|
||||
std::ofstream dummy_writer(out_metadata_file);
|
||||
assert(dummy_writer.is_open());
|
||||
for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) {
|
||||
dummy_writer << i->first << "," << i->second << std::endl;
|
||||
std::memcpy(data + i->first * ndims, data + i->second * ndims,
|
||||
ndims * sizeof(T));
|
||||
}
|
||||
dummy_writer.close();
|
||||
}
|
||||
|
||||
diskann::save_bin<T>(out_data_file, data, labels_per_point.size(), ndims);
|
||||
}
|
||||
|
||||
void extract_shard_labels(
|
||||
const std::string &in_label_file, const std::string &shard_ids_bin,
|
||||
const std::string &shard_label_file) { // assumes ith row is for ith
|
||||
// point in labels file
|
||||
diskann::cout << "Extracting labels for shard" << std::endl;
|
||||
|
||||
_u32 *ids = nullptr;
|
||||
_u64 num_ids, tmp_dim;
|
||||
diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim);
|
||||
|
||||
_u32 counter = 0, shard_counter = 0;
|
||||
std::string cur_line;
|
||||
|
||||
std::ifstream label_reader(in_label_file);
|
||||
std::ofstream label_writer(shard_label_file);
|
||||
assert(label_reader.is_open());
|
||||
assert(label_reader.is_open());
|
||||
if (label_reader && label_writer) {
|
||||
while (std::getline(label_reader, cur_line)) {
|
||||
if (shard_counter >= num_ids) {
|
||||
break;
|
||||
}
|
||||
if (counter == ids[shard_counter]) {
|
||||
label_writer << cur_line << "\n";
|
||||
shard_counter++;
|
||||
}
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
if (ids != nullptr)
|
||||
delete[] ids;
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
int build_merged_vamana_index(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_file,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf) {
|
||||
size_t base_num, base_dim;
|
||||
diskann::get_bin_metadata(base_file, base_num, base_dim);
|
||||
|
||||
double full_index_ram =
|
||||
estimate_ram_usage(base_num, base_dim, sizeof(T), R);
|
||||
|
||||
// TODO: Make this honest when there is filter support
|
||||
if (full_index_ram < ram_budget * 1024 * 1024 * 1024) {
|
||||
diskann::cout << "Full index fits in RAM budget, should consume at most "
|
||||
<< full_index_ram / (1024 * 1024 * 1024)
|
||||
<< "GiBs, so building in one shot" << std::endl;
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("L", (unsigned) L);
|
||||
paras.Set<unsigned>("Lf", (unsigned) Lf);
|
||||
paras.Set<unsigned>("R", (unsigned) R);
|
||||
paras.Set<unsigned>("C", 750);
|
||||
paras.Set<float>("alpha", 1.2f);
|
||||
paras.Set<unsigned>("num_rnds", 2);
|
||||
paras.Set<bool>("saturate_graph", 1);
|
||||
if (!use_filters)
|
||||
paras.Set<bool>("saturate_graph", 1);
|
||||
else
|
||||
paras.Set<bool>("saturate_graph", 0);
|
||||
using TagT = uint32_t;
|
||||
paras.Set<std::string>("save_path", mem_index_path);
|
||||
|
||||
std::unique_ptr<diskann::Index<T>> _pvamanaIndex =
|
||||
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
|
||||
compareMetric, base_dim, base_num, false, false, false,
|
||||
build_pq_bytes > 0, build_pq_bytes, use_opq));
|
||||
_pvamanaIndex->build(base_file.c_str(), base_num, paras);
|
||||
|
||||
std::unique_ptr<diskann::Index<T, TagT, LabelT>> _pvamanaIndex =
|
||||
std::unique_ptr<diskann::Index<T, TagT, LabelT>>(
|
||||
new diskann::Index<T, TagT, LabelT>(
|
||||
compareMetric, base_dim, base_num, false, false, false,
|
||||
build_pq_bytes > 0, build_pq_bytes, use_opq));
|
||||
if (!use_filters)
|
||||
_pvamanaIndex->build(base_file.c_str(), base_num, paras);
|
||||
else {
|
||||
if (universal_label != "") { // indicates no universal label
|
||||
LabelT unv_label_as_num = 0;
|
||||
_pvamanaIndex->set_universal_label(unv_label_as_num);
|
||||
}
|
||||
_pvamanaIndex->build_filtered_index(base_file.c_str(), label_file,
|
||||
base_num, paras);
|
||||
}
|
||||
_pvamanaIndex->save(mem_index_path.c_str());
|
||||
|
||||
if (use_filters) {
|
||||
// need to copy the labels_to_medoids file to the specified input file
|
||||
std::remove(labels_to_medoids_file.c_str());
|
||||
std::string mem_labels_to_medoid_file =
|
||||
mem_index_path + "_labels_to_medoids.txt";
|
||||
copy_file(mem_labels_to_medoid_file, labels_to_medoids_file);
|
||||
std::remove(mem_labels_to_medoid_file.c_str());
|
||||
}
|
||||
|
||||
std::remove(medoids_file.c_str());
|
||||
std::remove(centroids_file.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
// where the universal label is to be saved in the final graph
|
||||
std::string final_index_universal_label_file =
|
||||
mem_index_path + "_universal_label.txt";
|
||||
|
||||
std::string merged_index_prefix = mem_index_path + "_tempFiles";
|
||||
|
||||
Timer timer;
|
||||
int num_parts =
|
||||
int num_parts =
|
||||
partition_with_ram_budget<T>(base_file, sampling_rate, ram_budget,
|
||||
2 * R / 3, merged_index_prefix, 2);
|
||||
diskann::cout << timer.elapsed_seconds_for_step("partitioning data")
|
||||
|
@ -475,6 +692,9 @@ namespace diskann {
|
|||
std::string shard_ids_file = merged_index_prefix + "_subshard-" +
|
||||
std::to_string(p) + "_ids_uint32.bin";
|
||||
|
||||
std::string shard_labels_file = merged_index_prefix + "_subshard-" +
|
||||
std::to_string(p) + "_labels.txt";
|
||||
|
||||
retrieve_shard_data_from_ids<T>(base_file, shard_ids_file,
|
||||
shard_base_file);
|
||||
|
||||
|
@ -483,6 +703,7 @@ namespace diskann {
|
|||
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>("Lf", Lf);
|
||||
paras.Set<unsigned>("R", (2 * (R / 3)));
|
||||
paras.Set<unsigned>("C", 750);
|
||||
paras.Set<float>("alpha", 1.2f);
|
||||
|
@ -496,17 +717,43 @@ namespace diskann {
|
|||
std::unique_ptr<diskann::Index<T>>(new diskann::Index<T>(
|
||||
compareMetric, shard_base_dim, shard_base_pts, false, false,
|
||||
false, build_pq_bytes > 0, build_pq_bytes, use_opq));
|
||||
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras);
|
||||
if (!use_filters) {
|
||||
_pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras);
|
||||
} else {
|
||||
diskann::extract_shard_labels(label_file, shard_ids_file,
|
||||
shard_labels_file);
|
||||
if (universal_label != "") { // indicates no universal label
|
||||
LabelT unv_label_as_num = 0;
|
||||
_pvamanaIndex->set_universal_label(unv_label_as_num);
|
||||
}
|
||||
_pvamanaIndex->build_filtered_index(
|
||||
shard_base_file.c_str(), shard_labels_file, shard_base_pts, paras);
|
||||
}
|
||||
_pvamanaIndex->save(shard_index_file.c_str());
|
||||
// copy universal label file from first shard to the final destination
|
||||
// index, since all shards anyway share the universal label
|
||||
if (p == 0) {
|
||||
std::string shard_universal_label_file =
|
||||
shard_index_file + "_universal_label.txt";
|
||||
if (universal_label != "") {
|
||||
copy_file(shard_universal_label_file,
|
||||
final_index_universal_label_file);
|
||||
}
|
||||
}
|
||||
|
||||
std::remove(shard_base_file.c_str());
|
||||
}
|
||||
diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") << std::endl;
|
||||
diskann::cout << timer.elapsed_seconds_for_step(
|
||||
"building indices on shards")
|
||||
<< std::endl;
|
||||
|
||||
timer.reset();
|
||||
diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index",
|
||||
merged_index_prefix + "_subshard-", "_ids_uint32.bin",
|
||||
num_parts, R, mem_index_path, medoids_file);
|
||||
diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl;
|
||||
num_parts, R, mem_index_path, medoids_file,
|
||||
use_filters, labels_to_medoids_file);
|
||||
diskann::cout << timer.elapsed_seconds_for_step("merging indices")
|
||||
<< std::endl;
|
||||
|
||||
// delete tempFiles
|
||||
for (int p = 0; p < num_parts; p++) {
|
||||
|
@ -514,6 +761,8 @@ namespace diskann {
|
|||
merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin";
|
||||
std::string shard_id_file = merged_index_prefix + "_subshard-" +
|
||||
std::to_string(p) + "_ids_uint32.bin";
|
||||
std::string shard_labels_file = merged_index_prefix + "_subshard-" +
|
||||
std::to_string(p) + "_labels.txt";
|
||||
std::string shard_index_file =
|
||||
merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index";
|
||||
std::string shard_index_file_data = shard_index_file + ".data";
|
||||
|
@ -522,6 +771,17 @@ namespace diskann {
|
|||
std::remove(shard_id_file.c_str());
|
||||
std::remove(shard_index_file.c_str());
|
||||
std::remove(shard_index_file_data.c_str());
|
||||
if (use_filters) {
|
||||
std::string shard_index_label_file = shard_index_file + "_labels.txt";
|
||||
std::string shard_index_univ_label_file =
|
||||
shard_index_file + "_universal_label.txt";
|
||||
std::string shard_index_label_map_file =
|
||||
shard_index_file + "_labels_to_medoids.txt";
|
||||
std::remove(shard_labels_file.c_str());
|
||||
std::remove(shard_index_label_file.c_str());
|
||||
std::remove(shard_index_label_map_file.c_str());
|
||||
std::remove(shard_index_univ_label_file.c_str());
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -530,11 +790,11 @@ namespace diskann {
|
|||
|
||||
// optimizes the beamwidth to maximize QPS for a given L_search subject to
|
||||
// 99.9 latency not blowing up
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT>
|
||||
uint32_t optimize_beamwidth(
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> &pFlashIndex, T *tuning_sample,
|
||||
_u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L,
|
||||
uint32_t nthreads, uint32_t start_bw) {
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> &pFlashIndex,
|
||||
T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim,
|
||||
uint32_t L, uint32_t nthreads, uint32_t start_bw) {
|
||||
uint32_t cur_bw = start_bw;
|
||||
double max_qps = 0;
|
||||
uint32_t best_bw = start_bw;
|
||||
|
@ -799,10 +1059,13 @@ namespace diskann {
|
|||
<< std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT>
|
||||
int build_disk_index(const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters,
|
||||
diskann::Metric compareMetric, bool use_opq) {
|
||||
diskann::Metric compareMetric, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label,
|
||||
const _u32 filter_threshold, const _u32 Lf) {
|
||||
std::stringstream parser;
|
||||
parser << std::string(indexBuildParameters);
|
||||
std::string cur_param;
|
||||
|
@ -863,7 +1126,9 @@ namespace diskann {
|
|||
|
||||
std::string base_file(dataFilePath);
|
||||
std::string data_file_to_use = base_file;
|
||||
std::string labels_file_original = label_file;
|
||||
std::string index_prefix_path(indexFilePath);
|
||||
std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt";
|
||||
std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin";
|
||||
std::string pq_compressed_vectors_path =
|
||||
index_prefix_path + "_pq_compressed.bin";
|
||||
|
@ -871,6 +1136,19 @@ namespace diskann {
|
|||
std::string disk_index_path = index_prefix_path + "_disk.index";
|
||||
std::string medoids_path = disk_index_path + "_medoids.bin";
|
||||
std::string centroids_path = disk_index_path + "_centroids.bin";
|
||||
|
||||
std::string labels_to_medoids_path =
|
||||
disk_index_path + "_labels_to_medoids.txt";
|
||||
std::string mem_labels_file = mem_index_path + "_labels.txt";
|
||||
std::string disk_labels_file = disk_index_path + "_labels.txt";
|
||||
std::string mem_univ_label_file = mem_index_path + "_universal_label.txt";
|
||||
std::string disk_univ_label_file = disk_index_path + "_universal_label.txt";
|
||||
std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt";
|
||||
std::string dummy_remap_file =
|
||||
disk_index_path +
|
||||
"_dummy_remap.txt"; // remap will be used if we break-up points of high
|
||||
// label-density to create copies
|
||||
|
||||
std::string sample_base_prefix = index_prefix_path + "_sample";
|
||||
// optional, used if disk index file must store pq data
|
||||
std::string disk_pq_pivots_path =
|
||||
|
@ -930,6 +1208,27 @@ namespace diskann {
|
|||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// If there is filter support, we break-up points which have too many labels
|
||||
// into replica dummy points which evenly distribute the filters. The rest
|
||||
// of index build happens on the augmented base and labels
|
||||
std::string augmented_data_file, augmented_labels_file;
|
||||
if (use_filters) {
|
||||
convert_labels_string_to_int(labels_file_original, labels_file_to_use,
|
||||
disk_labels_int_map_file, universal_label);
|
||||
augmented_data_file = index_prefix_path + "_augmented_data.bin";
|
||||
augmented_labels_file = index_prefix_path + "_augmented_labels.txt";
|
||||
if (filter_threshold != 0) {
|
||||
dummy_remap_file = index_prefix_path + "_dummy_remap.txt";
|
||||
breakup_dense_points<T>(
|
||||
data_file_to_use, labels_file_to_use, filter_threshold,
|
||||
augmented_data_file, augmented_labels_file,
|
||||
dummy_remap_file); // RKNOTE: This has large memory footprint, need
|
||||
// to make this streaming
|
||||
data_file_to_use = augmented_data_file;
|
||||
labels_file_to_use = augmented_labels_file;
|
||||
}
|
||||
}
|
||||
|
||||
size_t points_num, dim;
|
||||
|
||||
Timer timer;
|
||||
|
@ -956,7 +1255,8 @@ namespace diskann {
|
|||
generate_quantized_data<T>(data_file_to_use, pq_pivots_path,
|
||||
pq_compressed_vectors_path, compareMetric, p_val,
|
||||
num_pq_chunks, use_opq);
|
||||
diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl;
|
||||
diskann::cout << timer.elapsed_seconds_for_step("generating quantized data")
|
||||
<< std::endl;
|
||||
|
||||
// Gopal. Splitting diskann_dll into separate DLLs for search and build.
|
||||
// This code should only be available in the "build" DLL.
|
||||
|
@ -966,10 +1266,11 @@ namespace diskann {
|
|||
#endif
|
||||
|
||||
timer.reset();
|
||||
diskann::build_merged_vamana_index<T>(
|
||||
diskann::build_merged_vamana_index<T, LabelT>(
|
||||
data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val,
|
||||
indexing_ram_budget, mem_index_path, medoids_path, centroids_path,
|
||||
build_pq_bytes, use_opq);
|
||||
build_pq_bytes, use_opq, use_filters, labels_file_to_use,
|
||||
labels_to_medoids_path, universal_label, Lf);
|
||||
diskann::cout << timer.elapsed_seconds_for_step(
|
||||
"building merged vamana index")
|
||||
<< std::endl;
|
||||
|
@ -997,6 +1298,17 @@ namespace diskann {
|
|||
double sample_sampling_rate = num_sample_points / points_num;
|
||||
gen_random_slice<T>(data_file_to_use.c_str(), sample_base_prefix,
|
||||
sample_sampling_rate);
|
||||
if (use_filters) {
|
||||
copy_file(labels_file_to_use, disk_labels_file);
|
||||
std::remove(mem_labels_file.c_str());
|
||||
if (universal_label != "") {
|
||||
copy_file(mem_univ_label_file, disk_univ_label_file);
|
||||
std::remove(mem_univ_label_file.c_str());
|
||||
}
|
||||
std::remove(augmented_data_file.c_str());
|
||||
std::remove(augmented_labels_file.c_str());
|
||||
std::remove(labels_file_to_use.c_str());
|
||||
}
|
||||
|
||||
std::remove(mem_index_path.c_str());
|
||||
if (use_disk_pq)
|
||||
|
@ -1041,48 +1353,123 @@ namespace diskann {
|
|||
uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim);
|
||||
#endif
|
||||
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<int8_t>> &pFlashIndex,
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint32_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<int8_t, uint32_t>> &pFlashIndex,
|
||||
int8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<uint8_t>> &pFlashIndex,
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint32_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint32_t>> &pFlashIndex,
|
||||
uint8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<float>> &pFlashIndex,
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint32_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<float, uint32_t>> &pFlashIndex,
|
||||
float *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
|
||||
template DISKANN_DLLEXPORT int build_disk_index<int8_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<uint8_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<float>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq);
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<int8_t, uint16_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<int8_t, uint16_t>> &pFlashIndex,
|
||||
int8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<uint8_t, uint16_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<uint8_t, uint16_t>> &pFlashIndex,
|
||||
uint8_t *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
template DISKANN_DLLEXPORT uint32_t optimize_beamwidth<float, uint16_t>(
|
||||
std::unique_ptr<diskann::PQFlashIndex<float, uint16_t>> &pFlashIndex,
|
||||
float *tuning_sample, _u64 tuning_sample_num,
|
||||
_u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads,
|
||||
uint32_t start_bw);
|
||||
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t>(
|
||||
template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint32_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint32_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<float, uint32_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
// LabelT = uint16
|
||||
template DISKANN_DLLEXPORT int build_disk_index<int8_t, uint16_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<uint8_t, uint16_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_disk_index<float, uint16_t>(
|
||||
const char *dataFilePath, const char *indexFilePath,
|
||||
const char *indexBuildParameters, diskann::Metric compareMetric,
|
||||
bool use_opq, bool use_filters, const std::string &label_file,
|
||||
const std::string &universal_label, const _u32 filter_threshold,
|
||||
const _u32 Lf);
|
||||
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint32_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<float>(
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint32_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t>(
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint32_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq);
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
// Label=16_t
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<int8_t, uint16_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<float, uint16_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
template DISKANN_DLLEXPORT int build_merged_vamana_index<uint8_t, uint16_t>(
|
||||
std::string base_file, diskann::Metric compareMetric, unsigned L,
|
||||
unsigned R, double sampling_rate, double ram_budget,
|
||||
std::string mem_index_path, std::string medoids_path,
|
||||
std::string centroids_file, size_t build_pq_bytes, bool use_opq,
|
||||
bool use_filters, const std::string &label_file,
|
||||
const std::string &labels_to_medoids_file,
|
||||
const std::string &universal_label, const _u32 Lf);
|
||||
}; // namespace diskann
|
||||
|
|
1167
src/index.cpp
1167
src/index.cpp
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -29,7 +29,7 @@
|
|||
// block size for reading/ processing large files and matrices in blocks
|
||||
#define BLOCK_SIZE 5000000
|
||||
|
||||
//#define SAVE_INFLATED_PQ true
|
||||
// #define SAVE_INFLATED_PQ true
|
||||
|
||||
template<typename T>
|
||||
void gen_random_slice(const std::string base_file,
|
||||
|
@ -591,9 +591,9 @@ int partition_with_ram_budget(const std::string data_file,
|
|||
train_dim, k_base, cluster_sizes);
|
||||
|
||||
for (auto &p : cluster_sizes) {
|
||||
p = (_u64) (p /
|
||||
sampling_rate); // to account for the fact that p is the size
|
||||
// of the shard over the testing sample.
|
||||
// to account for the fact that p is the size of the shard over the
|
||||
// testing sample.
|
||||
p = (_u64) (p / sampling_rate);
|
||||
double cur_shard_ram_estimate =
|
||||
diskann::estimate_ram_usage(p, train_dim, sizeof(T), graph_degree);
|
||||
|
||||
|
|
|
@ -41,9 +41,9 @@
|
|||
|
||||
namespace diskann {
|
||||
|
||||
template<typename T>
|
||||
PQFlashIndex<T>::PQFlashIndex(std::shared_ptr<AlignedFileReader> &fileReader,
|
||||
diskann::Metric m)
|
||||
template<typename T, typename LabelT>
|
||||
PQFlashIndex<T, LabelT>::PQFlashIndex(
|
||||
std::shared_ptr<AlignedFileReader> &fileReader, diskann::Metric m)
|
||||
: reader(fileReader), metric(m) {
|
||||
if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) {
|
||||
if (std::is_floating_point<T>::value) {
|
||||
|
@ -63,8 +63,8 @@ namespace diskann {
|
|||
this->dist_cmp_float.reset(diskann::get_distance_function<float>(metric));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
PQFlashIndex<T>::~PQFlashIndex() {
|
||||
template<typename T, typename LabelT>
|
||||
PQFlashIndex<T, LabelT>::~PQFlashIndex() {
|
||||
#ifndef EXEC_ENV_OLS
|
||||
if (data != nullptr) {
|
||||
delete[] data;
|
||||
|
@ -86,10 +86,18 @@ namespace diskann {
|
|||
this->reader->deregister_all_threads();
|
||||
reader->close();
|
||||
}
|
||||
if (_pts_to_label_offsets != nullptr) {
|
||||
delete[] _pts_to_label_offsets;
|
||||
}
|
||||
|
||||
if (_pts_to_labels != nullptr) {
|
||||
delete[] _pts_to_labels;
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::setup_thread_data(_u64 nthreads, _u64 visited_reserve) {
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::setup_thread_data(_u64 nthreads,
|
||||
_u64 visited_reserve) {
|
||||
diskann::cout << "Setting up thread-specific contexts for nthreads: "
|
||||
<< nthreads << std::endl;
|
||||
// omp parallel for to generate unique thread IDs
|
||||
|
@ -107,8 +115,9 @@ namespace diskann {
|
|||
load_flag = true;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::load_cache_list(std::vector<uint32_t> &node_list) {
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::load_cache_list(
|
||||
std::vector<uint32_t> &node_list) {
|
||||
diskann::cout << "Loading the cache list into memory.." << std::flush;
|
||||
_u64 num_cached_nodes = node_list.size();
|
||||
|
||||
|
@ -180,14 +189,14 @@ namespace diskann {
|
|||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(
|
||||
MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
|
||||
_u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads,
|
||||
std::vector<uint32_t> &node_list) {
|
||||
#else
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::generate_cache_list_from_sample_queries(
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::generate_cache_list_from_sample_queries(
|
||||
std::string sample_bin, _u64 l_search, _u64 beamwidth,
|
||||
_u64 num_nodes_to_cache, uint32_t nthreads,
|
||||
std::vector<uint32_t> &node_list) {
|
||||
|
@ -245,10 +254,10 @@ namespace diskann {
|
|||
diskann::aligned_free(samples);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::cache_bfs_levels(_u64 num_nodes_to_cache,
|
||||
std::vector<uint32_t> &node_list,
|
||||
const bool shuffle) {
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::cache_bfs_levels(
|
||||
_u64 num_nodes_to_cache, std::vector<uint32_t> &node_list,
|
||||
const bool shuffle) {
|
||||
std::random_device rng;
|
||||
std::mt19937 urng(rng());
|
||||
|
||||
|
@ -379,8 +388,8 @@ namespace diskann {
|
|||
diskann::cout << "done" << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::use_medoids_data_as_centroids() {
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::use_medoids_data_as_centroids() {
|
||||
if (centroid_data != nullptr)
|
||||
aligned_free(centroid_data);
|
||||
alloc_aligned(((void **) ¢roid_data),
|
||||
|
@ -425,13 +434,170 @@ namespace diskann {
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
inline int32_t PQFlashIndex<T, LabelT>::get_filter_number(
|
||||
const LabelT &filter_label) {
|
||||
int idx = -1;
|
||||
for (_u32 i = 0; i < _filter_list.size(); i++) {
|
||||
if (_filter_list[i] == filter_label) {
|
||||
idx = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return idx;
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
std::unordered_map<std::string, LabelT>
|
||||
PQFlashIndex<T, LabelT>::load_label_map(const std::string &labels_map_file) {
|
||||
std::unordered_map<std::string, LabelT> string_to_int_mp;
|
||||
std::ifstream map_reader(labels_map_file);
|
||||
std::string line, token;
|
||||
LabelT token_as_num;
|
||||
std::string label_str;
|
||||
while (std::getline(map_reader, line)) {
|
||||
std::istringstream iss(line);
|
||||
getline(iss, token, '\t');
|
||||
label_str = token;
|
||||
getline(iss, token, '\t');
|
||||
token_as_num = std::stoul(token);
|
||||
string_to_int_mp[label_str] = token_as_num;
|
||||
}
|
||||
return string_to_int_mp;
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
LabelT PQFlashIndex<T, LabelT>::get_converted_label(
|
||||
const std::string &filter_label) {
|
||||
if (_label_map.find(filter_label) != _label_map.end()) {
|
||||
return _label_map[filter_label];
|
||||
}
|
||||
std::stringstream stream;
|
||||
stream << "Unable to find label in the Label Map";
|
||||
diskann::cerr << stream.str() << std::endl;
|
||||
throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::get_label_file_metadata(
|
||||
std::string map_file, _u32 &num_pts, _u32 &num_total_labels) {
|
||||
std::ifstream infile(map_file);
|
||||
std::string line, token;
|
||||
num_pts = 0;
|
||||
num_total_labels = 0;
|
||||
|
||||
while (std::getline(infile, line)) {
|
||||
std::istringstream iss(line);
|
||||
while (getline(iss, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
num_total_labels++;
|
||||
}
|
||||
num_pts++;
|
||||
}
|
||||
|
||||
diskann::cout << "Labels file metadata: num_points: " << num_pts
|
||||
<< ", #total_labels: " << num_total_labels << std::endl;
|
||||
infile.close();
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
inline bool PQFlashIndex<T, LabelT>::point_has_label(_u32 point_id,
|
||||
_u32 label_id) {
|
||||
_u32 start_vec = _pts_to_label_offsets[point_id];
|
||||
_u32 num_lbls = _pts_to_labels[start_vec];
|
||||
bool ret_val = false;
|
||||
for (_u32 i = 0; i < num_lbls; i++) {
|
||||
if (_pts_to_labels[start_vec + 1 + i] == label_id) {
|
||||
ret_val = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ret_val;
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::parse_label_file(const std::string &label_file,
|
||||
size_t &num_points_labels) {
|
||||
std::ifstream infile(label_file);
|
||||
if (infile.fail()) {
|
||||
throw diskann::ANNException(
|
||||
std::string("Failed to open file ") + label_file, -1);
|
||||
}
|
||||
|
||||
std::string line, token;
|
||||
_u32 line_cnt = 0;
|
||||
|
||||
_u32 num_pts_in_label_file;
|
||||
_u32 num_total_labels;
|
||||
get_label_file_metadata(label_file, num_pts_in_label_file,
|
||||
num_total_labels);
|
||||
|
||||
_pts_to_label_offsets = new _u32[num_pts_in_label_file];
|
||||
_pts_to_labels = new _u32[num_pts_in_label_file + num_total_labels];
|
||||
_u32 counter = 0;
|
||||
|
||||
while (std::getline(infile, line)) {
|
||||
std::istringstream iss(line);
|
||||
std::vector<_u32> lbls(0);
|
||||
|
||||
_pts_to_label_offsets[line_cnt] = counter;
|
||||
_u32 &num_lbls_in_cur_pt = _pts_to_labels[counter];
|
||||
num_lbls_in_cur_pt = 0;
|
||||
counter++;
|
||||
getline(iss, token, '\t');
|
||||
std::istringstream new_iss(token);
|
||||
while (getline(new_iss, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
LabelT token_as_num = std::stoul(token);
|
||||
if (_labels.find(token_as_num) == _labels.end()) {
|
||||
_filter_list.emplace_back(token_as_num);
|
||||
}
|
||||
int32_t filter_num = get_filter_number(token_as_num);
|
||||
if (filter_num == -1) {
|
||||
diskann::cout << "Error!! " << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
_pts_to_labels[counter++] = filter_num;
|
||||
num_lbls_in_cur_pt++;
|
||||
_labels.insert(token_as_num);
|
||||
}
|
||||
|
||||
if (num_lbls_in_cur_pt == 0) {
|
||||
diskann::cout << "No label found for point " << line_cnt << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
line_cnt++;
|
||||
}
|
||||
infile.close();
|
||||
num_points_labels = line_cnt;
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::set_universal_label(const LabelT &label) {
|
||||
int32_t temp_filter_num = get_filter_number(label);
|
||||
if (temp_filter_num == -1) {
|
||||
diskann::cout << "Error, could not find universal label. Exitting."
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
} else {
|
||||
_use_universal_label = true;
|
||||
_universal_filter_num = (_u32) temp_filter_num;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
int PQFlashIndex<T>::load(MemoryMappedFiles &files, uint32_t num_threads,
|
||||
const char *index_prefix) {
|
||||
template<typename T, typename LabelT>
|
||||
int PQFlashIndex<T, LabelT>::load(MemoryMappedFiles &files,
|
||||
uint32_t num_threads,
|
||||
const char *index_prefix) {
|
||||
#else
|
||||
template<typename T>
|
||||
int PQFlashIndex<T>::load(uint32_t num_threads, const char *index_prefix) {
|
||||
template<typename T, typename LabelT>
|
||||
int PQFlashIndex<T, LabelT>::load(uint32_t num_threads,
|
||||
const char *index_prefix) {
|
||||
#endif
|
||||
std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin";
|
||||
std::string pq_compressed_vectors =
|
||||
|
@ -449,14 +615,14 @@ namespace diskann {
|
|||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
int PQFlashIndex<T>::load_from_separate_paths(
|
||||
template<typename T, typename LabelT>
|
||||
int PQFlashIndex<T, LabelT>::load_from_separate_paths(
|
||||
diskann::MemoryMappedFiles &files, uint32_t num_threads,
|
||||
const char *index_filepath, const char *pivots_filepath,
|
||||
const char *compressed_filepath) {
|
||||
#else
|
||||
template<typename T>
|
||||
int PQFlashIndex<T>::load_from_separate_paths(
|
||||
template<typename T, typename LabelT>
|
||||
int PQFlashIndex<T, LabelT>::load_from_separate_paths(
|
||||
uint32_t num_threads, const char *index_filepath,
|
||||
const char *pivots_filepath, const char *compressed_filepath) {
|
||||
#endif
|
||||
|
@ -467,6 +633,15 @@ namespace diskann {
|
|||
std::string centroids_file =
|
||||
std::string(disk_index_file) + "_centroids.bin";
|
||||
|
||||
std::string labels_file = std ::string(disk_index_file) + "_labels.txt";
|
||||
std::string labels_to_medoids =
|
||||
std ::string(disk_index_file) + "_labels_to_medoids.txt";
|
||||
std::string dummy_map_file =
|
||||
std ::string(disk_index_file) + "_dummy_map.txt";
|
||||
std::string labels_map_file =
|
||||
std ::string(disk_index_file) + "_labels_map.txt";
|
||||
size_t num_pts_in_label_file = 0;
|
||||
|
||||
size_t pq_file_dim, pq_file_num_centroids;
|
||||
#ifdef EXEC_ENV_OLS
|
||||
get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim,
|
||||
|
@ -503,6 +678,77 @@ namespace diskann {
|
|||
|
||||
this->num_points = npts_u64;
|
||||
this->n_chunks = nchunks_u64;
|
||||
if (file_exists(labels_file)) {
|
||||
parse_label_file(labels_file, num_pts_in_label_file);
|
||||
assert(num_pts_in_label_file == this->num_points);
|
||||
_label_map = load_label_map(labels_map_file);
|
||||
if (file_exists(labels_to_medoids)) {
|
||||
std::ifstream medoid_stream(labels_to_medoids);
|
||||
assert(medoid_stream.is_open());
|
||||
std::string line, token;
|
||||
|
||||
_filter_to_medoid_id.clear();
|
||||
try {
|
||||
while (std::getline(medoid_stream, line)) {
|
||||
std::istringstream iss(line);
|
||||
_u32 cnt = 0;
|
||||
_u32 medoid;
|
||||
LabelT label;
|
||||
while (std::getline(iss, token, ',')) {
|
||||
if (cnt == 0)
|
||||
label = std::stoul(token);
|
||||
else
|
||||
medoid = (_u32) stoul(token);
|
||||
cnt++;
|
||||
}
|
||||
_filter_to_medoid_id[label] = medoid;
|
||||
}
|
||||
} catch (std::system_error &e) {
|
||||
throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__,
|
||||
__LINE__);
|
||||
}
|
||||
}
|
||||
std::string univ_label_file =
|
||||
std ::string(disk_index_file) + "_universal_label.txt";
|
||||
if (file_exists(univ_label_file)) {
|
||||
std::ifstream universal_label_reader(univ_label_file);
|
||||
assert(universal_label_reader.is_open());
|
||||
std::string univ_label;
|
||||
universal_label_reader >> univ_label;
|
||||
universal_label_reader.close();
|
||||
LabelT label_as_num = std::stoul(univ_label);
|
||||
set_universal_label(label_as_num);
|
||||
}
|
||||
if (file_exists(dummy_map_file)) {
|
||||
std::ifstream dummy_map_stream(dummy_map_file);
|
||||
assert(dummy_map_stream.is_open());
|
||||
std::string line, token;
|
||||
|
||||
while (std::getline(dummy_map_stream, line)) {
|
||||
std::istringstream iss(line);
|
||||
_u32 cnt = 0;
|
||||
_u32 dummy_id;
|
||||
_u32 real_id;
|
||||
while (std::getline(iss, token, ',')) {
|
||||
if (cnt == 0)
|
||||
dummy_id = (_u32) stoul(token);
|
||||
else
|
||||
real_id = (_u32) stoul(token);
|
||||
cnt++;
|
||||
}
|
||||
_dummy_pts.insert(dummy_id);
|
||||
_has_dummy_pts.insert(real_id);
|
||||
_dummy_to_real_map[dummy_id] = real_id;
|
||||
|
||||
if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end())
|
||||
_real_to_dummy_map[real_id] = std::vector<_u32>();
|
||||
|
||||
_real_to_dummy_map[real_id].emplace_back(dummy_id);
|
||||
}
|
||||
dummy_map_stream.close();
|
||||
diskann::cout << "Loaded dummy map" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64);
|
||||
|
@ -541,8 +787,8 @@ namespace diskann {
|
|||
disk_bytes_per_point =
|
||||
disk_pq_n_chunks *
|
||||
sizeof(_u8); // revising disk_bytes_per_point since DISK PQ is used.
|
||||
std::cout << "Disk index uses PQ data compressed down to "
|
||||
<< disk_pq_n_chunks << " bytes per point." << std::endl;
|
||||
diskann::cout << "Disk index uses PQ data compressed down to "
|
||||
<< disk_pq_n_chunks << " bytes per point." << std::endl;
|
||||
}
|
||||
|
||||
// read index metadata
|
||||
|
@ -704,8 +950,8 @@ namespace diskann {
|
|||
float *norm_val;
|
||||
diskann::load_bin<float>(norm_file, norm_val, dumr, dumc);
|
||||
this->max_base_norm = norm_val[0];
|
||||
std::cout << "Setting re-scaling factor of base vectors to "
|
||||
<< this->max_base_norm << std::endl;
|
||||
diskann::cout << "Setting re-scaling factor of base vectors to "
|
||||
<< this->max_base_norm << std::endl;
|
||||
delete[] norm_val;
|
||||
}
|
||||
diskann::cout << "done.." << std::endl;
|
||||
|
@ -742,23 +988,58 @@ namespace diskann {
|
|||
}
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::cached_beam_search(const T *query1, const _u64 k_search,
|
||||
const _u64 l_search, _u64 *indices,
|
||||
float *distances,
|
||||
const _u64 beam_width,
|
||||
const bool use_reorder_data,
|
||||
QueryStats *stats) {
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::cached_beam_search(
|
||||
const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
|
||||
float *distances, const _u64 beam_width, const bool use_reorder_data,
|
||||
QueryStats *stats) {
|
||||
cached_beam_search(query1, k_search, l_search, indices, distances,
|
||||
beam_width, std::numeric_limits<_u32>::max(),
|
||||
use_reorder_data, stats);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void PQFlashIndex<T>::cached_beam_search(
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::cached_beam_search(
|
||||
const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
|
||||
float *distances, const _u64 beam_width, const bool use_filter,
|
||||
const LabelT &filter_label, const bool use_reorder_data,
|
||||
QueryStats *stats) {
|
||||
cached_beam_search(query1, k_search, l_search, indices, distances,
|
||||
beam_width, use_filter, filter_label,
|
||||
std::numeric_limits<_u32>::max(), use_reorder_data,
|
||||
stats);
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::cached_beam_search(
|
||||
const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
|
||||
float *distances, const _u64 beam_width, const _u32 io_limit,
|
||||
const bool use_reorder_data, QueryStats *stats) {
|
||||
LabelT dummy_filter = 0;
|
||||
cached_beam_search(query1, k_search, l_search, indices, distances,
|
||||
beam_width, false, dummy_filter,
|
||||
std::numeric_limits<_u32>::max(), use_reorder_data,
|
||||
stats);
|
||||
}
|
||||
|
||||
template<typename T, typename LabelT>
|
||||
void PQFlashIndex<T, LabelT>::cached_beam_search(
|
||||
const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices,
|
||||
float *distances, const _u64 beam_width, const bool use_filter,
|
||||
const LabelT &filter_label, const _u32 io_limit,
|
||||
const bool use_reorder_data, QueryStats *stats) {
|
||||
int32_t filter_num = 0;
|
||||
if (use_filter) {
|
||||
filter_num = get_filter_number(filter_label);
|
||||
if (filter_num < 0) {
|
||||
if (!_use_universal_label) {
|
||||
return;
|
||||
} else {
|
||||
filter_num = _universal_filter_num;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (beam_width > MAX_N_SECTOR_READS)
|
||||
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS",
|
||||
-1, __FUNCSIG__, __FILE__, __LINE__);
|
||||
|
@ -838,14 +1119,22 @@ namespace diskann {
|
|||
|
||||
_u32 best_medoid = 0;
|
||||
float best_dist = (std::numeric_limits<float>::max)();
|
||||
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
|
||||
float cur_expanded_dist = dist_cmp_float->compare(
|
||||
query_float, centroid_data + aligned_dim * cur_m,
|
||||
(unsigned) aligned_dim);
|
||||
if (cur_expanded_dist < best_dist) {
|
||||
best_medoid = medoids[cur_m];
|
||||
best_dist = cur_expanded_dist;
|
||||
if (!use_filter) {
|
||||
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
|
||||
float cur_expanded_dist = dist_cmp_float->compare(
|
||||
query_float, centroid_data + aligned_dim * cur_m,
|
||||
(unsigned) aligned_dim);
|
||||
if (cur_expanded_dist < best_dist) {
|
||||
best_medoid = medoids[cur_m];
|
||||
best_dist = cur_expanded_dist;
|
||||
}
|
||||
}
|
||||
} else if (_filter_to_medoid_id.find(filter_label) !=
|
||||
_filter_to_medoid_id.end()) {
|
||||
best_medoid = _filter_to_medoid_id[filter_label];
|
||||
} else {
|
||||
throw ANNException("Cannot find medoid for specified filter.", -1,
|
||||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
}
|
||||
|
||||
compute_dists(&best_medoid, 1, dist_scratch);
|
||||
|
@ -963,6 +1252,12 @@ namespace diskann {
|
|||
for (_u64 m = 0; m < nnbrs; ++m) {
|
||||
unsigned id = node_nbrs[m];
|
||||
if (visited.insert(id).second) {
|
||||
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
|
||||
continue;
|
||||
|
||||
if (use_filter && !point_has_label(id, filter_num) &&
|
||||
!point_has_label(id, _universal_filter_num))
|
||||
continue;
|
||||
cmps++;
|
||||
float dist = dist_scratch[m];
|
||||
Neighbor nn(id, dist);
|
||||
|
@ -1024,6 +1319,12 @@ namespace diskann {
|
|||
for (_u64 m = 0; m < nnbrs; ++m) {
|
||||
unsigned id = node_nbrs[m];
|
||||
if (visited.insert(id).second) {
|
||||
if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end())
|
||||
continue;
|
||||
|
||||
if (use_filter && !point_has_label(id, filter_num) &&
|
||||
!point_has_label(id, _universal_filter_num))
|
||||
continue;
|
||||
cmps++;
|
||||
float dist = dist_scratch[m];
|
||||
if (stats != nullptr) {
|
||||
|
@ -1096,6 +1397,11 @@ namespace diskann {
|
|||
// copy k_search values
|
||||
for (_u64 i = 0; i < k_search; i++) {
|
||||
indices[i] = full_retset[i].id;
|
||||
|
||||
if (_dummy_pts.find(indices[i]) != _dummy_pts.end()) {
|
||||
indices[i] = _dummy_to_real_map[indices[i]];
|
||||
}
|
||||
|
||||
if (distances != nullptr) {
|
||||
distances[i] = full_retset[i].distance;
|
||||
if (metric == diskann::Metric::INNER_PRODUCT) {
|
||||
|
@ -1121,14 +1427,12 @@ namespace diskann {
|
|||
// range search returns results of all neighbors within distance of range.
|
||||
// indices and distances need to be pre-allocated of size l_search and the
|
||||
// return value is the number of matching hits.
|
||||
template<typename T>
|
||||
_u32 PQFlashIndex<T>::range_search(const T *query1, const double range,
|
||||
const _u64 min_l_search,
|
||||
const _u64 max_l_search,
|
||||
std::vector<_u64> &indices,
|
||||
std::vector<float> &distances,
|
||||
const _u64 min_beam_width,
|
||||
QueryStats *stats) {
|
||||
template<typename T, typename LabelT>
|
||||
_u32 PQFlashIndex<T, LabelT>::range_search(
|
||||
const T *query1, const double range, const _u64 min_l_search,
|
||||
const _u64 max_l_search, std::vector<_u64> &indices,
|
||||
std::vector<float> &distances, const _u64 min_beam_width,
|
||||
QueryStats *stats) {
|
||||
_u32 res_count = 0;
|
||||
|
||||
bool stop_flag = false;
|
||||
|
@ -1162,23 +1466,23 @@ namespace diskann {
|
|||
return res_count;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
_u64 PQFlashIndex<T>::get_data_dim() {
|
||||
template<typename T, typename LabelT>
|
||||
_u64 PQFlashIndex<T, LabelT>::get_data_dim() {
|
||||
return data_dim;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
diskann::Metric PQFlashIndex<T>::get_metric() {
|
||||
template<typename T, typename LabelT>
|
||||
diskann::Metric PQFlashIndex<T, LabelT>::get_metric() {
|
||||
return this->metric;
|
||||
}
|
||||
|
||||
#ifdef EXEC_ENV_OLS
|
||||
template<typename T>
|
||||
char *PQFlashIndex<T>::getHeaderBytes() {
|
||||
template<typename T, typename LabelT>
|
||||
char *PQFlashIndex<T, LabelT>::getHeaderBytes() {
|
||||
IOContext &ctx = reader->get_ctx();
|
||||
AlignedRead readReq;
|
||||
readReq.buf = new char[PQFlashIndex<T>::HEADER_SIZE];
|
||||
readReq.len = PQFlashIndex<T>::HEADER_SIZE;
|
||||
readReq.buf = new char[PQFlashIndex<T, LabelT>::HEADER_SIZE];
|
||||
readReq.len = PQFlashIndex<T, LabelT>::HEADER_SIZE;
|
||||
readReq.offset = 0;
|
||||
|
||||
std::vector<AlignedRead> readReqs;
|
||||
|
@ -1194,5 +1498,8 @@ namespace diskann {
|
|||
template class PQFlashIndex<_u8>;
|
||||
template class PQFlashIndex<_s8>;
|
||||
template class PQFlashIndex<float>;
|
||||
template class PQFlashIndex<_u8, uint16_t>;
|
||||
template class PQFlashIndex<_s8, uint16_t>;
|
||||
template class PQFlashIndex<float, uint16_t>;
|
||||
|
||||
} // namespace diskann
|
||||
|
|
|
@ -87,15 +87,11 @@ namespace diskann {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
InMemorySearch<T>::InMemorySearch(
|
||||
const std::string& baseFile,
|
||||
const std::string& indexFile,
|
||||
const std::string& tagsFile,
|
||||
Metric m,
|
||||
uint32_t num_threads,
|
||||
uint32_t search_l
|
||||
): BaseSearch(tagsFile) {
|
||||
|
||||
InMemorySearch<T>::InMemorySearch(const std::string& baseFile,
|
||||
const std::string& indexFile,
|
||||
const std::string& tagsFile, Metric m,
|
||||
uint32_t num_threads, uint32_t search_l)
|
||||
: BaseSearch(tagsFile) {
|
||||
size_t dimensions, total_points = 0;
|
||||
diskann::get_bin_metadata(baseFile, total_points, dimensions);
|
||||
_index = std::unique_ptr<diskann::Index<T>>(
|
||||
|
@ -136,9 +132,9 @@ namespace diskann {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
PQFlashSearch<T>::PQFlashSearch(const std::string& indexPrefix,
|
||||
const unsigned num_nodes_to_cache,
|
||||
const unsigned num_threads,
|
||||
PQFlashSearch<T>::PQFlashSearch(const std::string& indexPrefix,
|
||||
const unsigned num_nodes_to_cache,
|
||||
const unsigned num_threads,
|
||||
const std::string& tagsFile, Metric m)
|
||||
: BaseSearch(tagsFile) {
|
||||
#ifdef _WINDOWS
|
||||
|
@ -162,7 +158,8 @@ namespace diskann {
|
|||
int res = _index->load(num_threads, index_prefix_path.c_str());
|
||||
|
||||
if (res != 0) {
|
||||
std::cerr << "Unable to load index. Status code: " << res << "." << std::endl;
|
||||
std::cerr << "Unable to load index. Status code: " << res << "."
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> node_list;
|
||||
|
@ -174,13 +171,13 @@ namespace diskann {
|
|||
}
|
||||
|
||||
template<typename T>
|
||||
SearchResult PQFlashSearch<T>::search(const T* query,
|
||||
SearchResult PQFlashSearch<T>::search(const T* query,
|
||||
const unsigned int dimensions,
|
||||
const unsigned int K,
|
||||
const unsigned int Ls) {
|
||||
_u64* indices_u64 = new _u64[K];
|
||||
_u64* indices_u64 = new _u64[K];
|
||||
unsigned* indices = new unsigned[K];
|
||||
float* distances = new float[K];
|
||||
float* distances = new float[K];
|
||||
|
||||
auto startTime = std::chrono::high_resolution_clock::now();
|
||||
_index->cached_beam_search(query, K, Ls, indices_u64, distances, DEFAULT_W);
|
||||
|
|
|
@ -62,7 +62,7 @@ namespace diskann {
|
|||
std::vector<size_t> pos(numsearchers, 0);
|
||||
|
||||
for (size_t k = 0; k < K; ++k) {
|
||||
float best_distance = std::numeric_limits<float>::max();
|
||||
float best_distance = std::numeric_limits<float>::max();
|
||||
unsigned best_partition = 0;
|
||||
|
||||
for (size_t i = 0; i < numsearchers; ++i) {
|
||||
|
@ -71,20 +71,23 @@ namespace diskann {
|
|||
best_partition = i;
|
||||
}
|
||||
}
|
||||
best_distances[k] = best_distance;
|
||||
best_indices[k] = results[best_partition].get_indices()[pos[best_partition]];
|
||||
best_partitions[k] = best_partition;
|
||||
if (results[best_partition].tags_enabled())
|
||||
best_tags[k] = results[best_partition].get_tags()[pos[best_partition]];
|
||||
std::cout << best_partition << " " << pos[best_partition] << std::endl;
|
||||
pos[best_partition]++;
|
||||
best_distances[k] = best_distance;
|
||||
best_indices[k] =
|
||||
results[best_partition].get_indices()[pos[best_partition]];
|
||||
best_partitions[k] = best_partition;
|
||||
if (results[best_partition].tags_enabled())
|
||||
best_tags[k] =
|
||||
results[best_partition].get_tags()[pos[best_partition]];
|
||||
std::cout << best_partition << " " << pos[best_partition] << std::endl;
|
||||
pos[best_partition]++;
|
||||
}
|
||||
|
||||
unsigned int total_time = 0;
|
||||
for (size_t i = 0; i < numsearchers; ++i)
|
||||
total_time += results[i].get_time();
|
||||
diskann::SearchResult result = SearchResult(
|
||||
K, total_time, best_indices, best_distances, best_tags, best_partitions);
|
||||
diskann::SearchResult result =
|
||||
SearchResult(K, total_time, best_indices, best_distances, best_tags,
|
||||
best_partitions);
|
||||
|
||||
delete[] best_indices;
|
||||
delete[] best_distances;
|
||||
|
@ -101,8 +104,8 @@ namespace diskann {
|
|||
void Server::handle_post(web::http::http_request message) {
|
||||
message.extract_string(true)
|
||||
.then([=](utility::string_t body) {
|
||||
int64_t queryId = -1;
|
||||
unsigned int K = 0;
|
||||
int64_t queryId = -1;
|
||||
unsigned int K = 0;
|
||||
try {
|
||||
T* queryVector = nullptr;
|
||||
unsigned int dimensions = 0;
|
||||
|
@ -113,7 +116,8 @@ namespace diskann {
|
|||
std::vector<diskann::SearchResult> results;
|
||||
|
||||
for (auto& searcher : _multi_searcher)
|
||||
results.push_back(searcher->search(queryVector, dimensions, (unsigned int) K, Ls));
|
||||
results.push_back(searcher->search(queryVector, dimensions,
|
||||
(unsigned int) K, Ls));
|
||||
diskann::SearchResult result = aggregate_results(K, results);
|
||||
diskann::aligned_free(queryVector);
|
||||
web::json::value response = prepareResponse(queryId, K);
|
||||
|
@ -139,13 +143,17 @@ namespace diskann {
|
|||
return std::make_pair(web::http::status_codes::InternalError,
|
||||
response);
|
||||
} catch (...) {
|
||||
std::cerr << "Uncaught exception while processing query: " << queryId;
|
||||
std::cerr << "Uncaught exception while processing query: "
|
||||
<< queryId;
|
||||
web::json::value response = prepareResponse(queryId, K);
|
||||
response[ERROR_MESSAGE_KEY] = web::json::value::string(UNKNOWN_ERROR);
|
||||
return std::make_pair(web::http::status_codes::InternalError, response);
|
||||
response[ERROR_MESSAGE_KEY] =
|
||||
web::json::value::string(UNKNOWN_ERROR);
|
||||
return std::make_pair(web::http::status_codes::InternalError,
|
||||
response);
|
||||
}
|
||||
})
|
||||
.then([=](std::pair<short unsigned int, web::json::value> response_status) {
|
||||
.then([=](std::pair<short unsigned int, web::json::value>
|
||||
response_status) {
|
||||
try {
|
||||
message.reply(response_status.first, response_status.second).wait();
|
||||
} catch (const std::exception& ex) {
|
||||
|
@ -180,7 +188,8 @@ namespace diskann {
|
|||
|
||||
if (k <= 0 || k > Ls) {
|
||||
throw new std::invalid_argument(
|
||||
"Num of expected NN (k) must be greater than zero and less than or equal to Ls.");
|
||||
"Num of expected NN (k) must be greater than zero and less than or "
|
||||
"equal to Ls.");
|
||||
}
|
||||
if (queryArr.size() == 0) {
|
||||
throw new std::invalid_argument("Query vector has zero elements.");
|
||||
|
|
|
@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 14)
|
|||
add_executable(build_memory_index build_memory_index.cpp)
|
||||
target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(build_stitched_index build_stitched_index.cpp)
|
||||
target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
add_executable(search_memory_index search_memory_index.cpp)
|
||||
target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options)
|
||||
|
||||
|
|
|
@ -13,11 +13,12 @@
|
|||
namespace po = boost::program_options;
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix;
|
||||
unsigned num_threads, R, L, disk_PQ, build_PQ;
|
||||
float B, M;
|
||||
bool append_reorder_data = false;
|
||||
bool use_opq = false;
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file,
|
||||
universal_label, label_type;
|
||||
unsigned num_threads, R, L, disk_PQ, build_PQ, Lf, filter_threshold;
|
||||
float B, M;
|
||||
bool append_reorder_data = false;
|
||||
bool use_opq = false;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
|
@ -60,9 +61,34 @@ int main(int argc, char** argv) {
|
|||
desc.add_options()(
|
||||
"build_PQ_bytes", po::value<uint32_t>(&build_PQ)->default_value(0),
|
||||
"Number of PQ bytes to build the index; 0 for full precision build");
|
||||
|
||||
desc.add_options()("use_opq", po::bool_switch()->default_value(false),
|
||||
"Use Optimized Product Quantization (OPQ).");
|
||||
desc.add_options()(
|
||||
"label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index build ."
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
desc.add_options()(
|
||||
"universal_label",
|
||||
po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, Use only in conjuction with label file for filtered "
|
||||
"index build. If a graph node has all the labels against it, we can "
|
||||
"assign a special universal filter to the point instead of comma "
|
||||
"separated filters for that point");
|
||||
desc.add_options()("filtered_Lbuild,Lf",
|
||||
po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
desc.add_options()(
|
||||
"filter_threshold,F",
|
||||
po::value<uint32_t>(&filter_threshold)->default_value(0),
|
||||
"Threshold to break up the existing nodes to generate new graph "
|
||||
"internally where each node has a maximum F labels.");
|
||||
desc.add_options()(
|
||||
"label_type",
|
||||
po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
|
@ -80,6 +106,11 @@ int main(int argc, char** argv) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
bool use_filters = false;
|
||||
if (label_file != "") {
|
||||
use_filters = true;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
|
@ -116,21 +147,46 @@ int main(int argc, char** argv) {
|
|||
std::string(std::to_string(build_PQ));
|
||||
|
||||
try {
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(data_path.c_str(),
|
||||
index_path_prefix.c_str(),
|
||||
params.c_str(), metric, use_opq);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric,
|
||||
use_opq);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float>(data_path.c_str(),
|
||||
index_path_prefix.c_str(),
|
||||
params.c_str(), metric, use_opq);
|
||||
else {
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
if (label_file != "" && label_type == "ushort") {
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float, uint16_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else {
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
if (data_type == std::string("int8"))
|
||||
return diskann::build_disk_index<int8_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return diskann::build_disk_index<uint8_t>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return diskann::build_disk_index<float>(
|
||||
data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
|
||||
metric, use_opq, use_filters, label_file, universal_label,
|
||||
filter_threshold, Lf);
|
||||
else {
|
||||
diskann::cerr << "Error. Unsupported data type" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
|
|
|
@ -20,44 +20,62 @@
|
|||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template<typename T, typename TagT = uint32_t>
|
||||
template<typename T, typename TagT = uint32_t, typename LabelT = uint32_t>
|
||||
int build_in_memory_index(const diskann::Metric& metric,
|
||||
const std::string& data_path, const unsigned R,
|
||||
const unsigned L, const float alpha,
|
||||
const std::string& save_path,
|
||||
const unsigned num_threads, const bool use_pq_build,
|
||||
const size_t num_pq_bytes, const bool use_opq) {
|
||||
const size_t num_pq_bytes, const bool use_opq,
|
||||
const std::string& label_file,
|
||||
const std::string& universal_label, const _u32 Lf) {
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("R", R);
|
||||
paras.Set<unsigned>("L", L);
|
||||
paras.Set<unsigned>("Lf", Lf);
|
||||
paras.Set<unsigned>(
|
||||
"C", 750); // maximum candidate set size during pruning procedure
|
||||
paras.Set<float>("alpha", alpha);
|
||||
paras.Set<bool>("saturate_graph", 0);
|
||||
paras.Set<unsigned>("num_threads", num_threads);
|
||||
std::string labels_file_to_use = save_path + "_label_formatted.txt";
|
||||
std::string mem_labels_int_map_file = save_path + "_labels_map.txt";
|
||||
|
||||
_u64 data_num, data_dim;
|
||||
diskann::get_bin_metadata(data_path, data_num, data_dim);
|
||||
|
||||
diskann::Index<T, TagT> index(metric, data_dim, data_num, false, false, false,
|
||||
use_pq_build, num_pq_bytes, use_opq);
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
index.build(data_path.c_str(), data_num, paras);
|
||||
|
||||
diskann::Index<T, TagT, LabelT> index(metric, data_dim, data_num, false,
|
||||
false, false, use_pq_build,
|
||||
num_pq_bytes, use_opq);
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
if (label_file == "") {
|
||||
index.build(data_path.c_str(), data_num, paras);
|
||||
} else {
|
||||
convert_labels_string_to_int(label_file, labels_file_to_use,
|
||||
mem_labels_int_map_file, universal_label);
|
||||
if (universal_label != "") {
|
||||
LabelT unv_label_as_num = std::stoul(universal_label);
|
||||
index.set_universal_label(unv_label_as_num);
|
||||
}
|
||||
index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num,
|
||||
paras);
|
||||
}
|
||||
std::chrono::duration<double> diff =
|
||||
std::chrono::high_resolution_clock::now() - s;
|
||||
|
||||
std::cout << "Indexing time: " << diff.count() << "\n";
|
||||
index.save(save_path.c_str());
|
||||
|
||||
if (label_file != "")
|
||||
std::remove(labels_file_to_use.c_str());
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix;
|
||||
unsigned num_threads, R, L, build_PQ_bytes;
|
||||
float alpha;
|
||||
bool use_pq_build, use_opq;
|
||||
std::string data_type, dist_fn, data_path, index_path_prefix, label_file,
|
||||
universal_label, label_type;
|
||||
unsigned num_threads, R, L, Lf, build_PQ_bytes;
|
||||
float alpha;
|
||||
bool use_pq_build, use_opq;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
|
@ -89,12 +107,31 @@ int main(int argc, char** argv) {
|
|||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()(
|
||||
"build_PQ_bytes", po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
|
||||
"build_PQ_bytes",
|
||||
po::value<uint32_t>(&build_PQ_bytes)->default_value(0),
|
||||
"Number of PQ bytes to build the index; 0 for full precision build");
|
||||
desc.add_options()(
|
||||
"use_opq", po::bool_switch()->default_value(false),
|
||||
"Set true for OPQ compression while using PQ distance comparisons for "
|
||||
"building the index, and false for PQ compression");
|
||||
desc.add_options()(
|
||||
"label_file", po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input label file in txt format for Filtered Index search. "
|
||||
"The file should contain comma separated filters for each node "
|
||||
"with each line corresponding to a graph node");
|
||||
desc.add_options()(
|
||||
"universal_label",
|
||||
po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with labels_file");
|
||||
desc.add_options()("FilteredLbuild,Lf",
|
||||
po::value<uint32_t>(&Lf)->default_value(0),
|
||||
"Build complexity for filtered points, higher value "
|
||||
"results in better graphs");
|
||||
desc.add_options()(
|
||||
"label_type",
|
||||
po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
|
@ -128,22 +165,48 @@ int main(int argc, char** argv) {
|
|||
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L
|
||||
<< " alpha: " << alpha << " #threads: " << num_threads
|
||||
<< std::endl;
|
||||
if (data_type == std::string("int8"))
|
||||
return build_in_memory_index<int8_t>(metric, data_path, R, L, alpha,
|
||||
index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return build_in_memory_index<uint8_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq);
|
||||
else if (data_type == std::string("float"))
|
||||
return build_in_memory_index<float>(metric, data_path, R, L, alpha,
|
||||
index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq);
|
||||
else {
|
||||
std::cout << "Unsupported type. Use one of int8, uint8 or float."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
if (label_file != "" && label_type == "ushort") {
|
||||
if (data_type == std::string("int8"))
|
||||
return build_in_memory_index<int8_t, uint32_t, uint16_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return build_in_memory_index<uint8_t, uint32_t, uint16_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return build_in_memory_index<float, uint32_t, uint16_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else {
|
||||
std::cout << "Unsupported type. Use one of int8, uint8 or float."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
if (data_type == std::string("int8"))
|
||||
return build_in_memory_index<int8_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return build_in_memory_index<uint8_t>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else if (data_type == std::string("float"))
|
||||
return build_in_memory_index<float>(
|
||||
metric, data_path, R, L, alpha, index_path_prefix, num_threads,
|
||||
use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label,
|
||||
Lf);
|
||||
else {
|
||||
std::cout << "Unsupported type. Use one of int8, uint8 or float."
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
|
|
|
@ -0,0 +1,855 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <boost/program_options.hpp>
|
||||
#include <chrono>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include <omp.h>
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/uio.h>
|
||||
#endif
|
||||
|
||||
#include "index.h"
|
||||
#include "memory_mapper.h"
|
||||
#include "parameters.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
// macros
|
||||
#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"
|
||||
#define PBWIDTH 60
|
||||
|
||||
// custom types (for readability)
|
||||
typedef tsl::robin_set<std::string> label_set;
|
||||
typedef std::string path;
|
||||
|
||||
// structs for returning multiple items from a function
|
||||
typedef std::tuple<std::vector<label_set>, tsl::robin_map<std::string, _u32>,
|
||||
label_set>
|
||||
parse_label_file_return_values;
|
||||
typedef std::tuple<std::vector<std::vector<_u32>>, _u64>
|
||||
load_label_index_return_values;
|
||||
typedef std::tuple<std::vector<std::vector<_u32>>, _u64>
|
||||
stitch_indices_return_values;
|
||||
|
||||
/*
|
||||
* Inline function to display progress bar.
|
||||
*/
|
||||
inline void print_progress(double percentage) {
|
||||
int val = (int) (percentage * 100);
|
||||
int lpad = (int) (percentage * PBWIDTH);
|
||||
int rpad = PBWIDTH - lpad;
|
||||
printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, "");
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
/*
|
||||
* Inline function to generate a random integer in a range.
|
||||
*/
|
||||
inline size_t random(size_t range_from, size_t range_to) {
|
||||
std::random_device rand_dev;
|
||||
std::mt19937 generator(rand_dev());
|
||||
std::uniform_int_distribution<size_t> distr(range_from, range_to);
|
||||
return distr(generator);
|
||||
}
|
||||
|
||||
/*
|
||||
* function to handle command line parsing.
|
||||
*
|
||||
* Arguments are merely the inputs from the command line.
|
||||
*/
|
||||
void handle_args(int argc, char **argv, std::string &data_type,
|
||||
path &input_data_path, path &final_index_path_prefix,
|
||||
path &label_data_path, std::string &universal_label,
|
||||
unsigned &num_threads, unsigned &R, unsigned &L,
|
||||
unsigned &stitched_R, float &alpha) {
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("data_path",
|
||||
po::value<path>(&input_data_path)->required(),
|
||||
"Input data file in bin format");
|
||||
desc.add_options()("index_path_prefix",
|
||||
po::value<path>(&final_index_path_prefix)->required(),
|
||||
"Path prefix for saving index file components");
|
||||
desc.add_options()("max_degree,R",
|
||||
po::value<uint32_t>(&R)->default_value(64),
|
||||
"Maximum graph degree");
|
||||
desc.add_options()(
|
||||
"Lbuild,L", po::value<uint32_t>(&L)->default_value(100),
|
||||
"Build complexity, higher value results in better graphs");
|
||||
desc.add_options()("stitched_R",
|
||||
po::value<uint32_t>(&stitched_R)->default_value(100),
|
||||
"Degree to prune final graph down to");
|
||||
desc.add_options()(
|
||||
"alpha", po::value<float>(&alpha)->default_value(1.2f),
|
||||
"alpha controls density and diameter of graph, set 1 for sparse graph, "
|
||||
"1.2 or 1.4 for denser graphs with lower diameter");
|
||||
desc.add_options()(
|
||||
"num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("label_file",
|
||||
po::value<path>(&label_data_path)->default_value(""),
|
||||
"Input label file in txt format if present");
|
||||
desc.add_options()(
|
||||
"universal_label",
|
||||
po::value<std::string>(&universal_label)->default_value(""),
|
||||
"If a point comes with the specified universal label (and only the "
|
||||
"univ. "
|
||||
"label), then the point is considered to have every possible label");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
exit(0);
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception &ex) {
|
||||
std::cerr << ex.what() << '\n';
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Parses the label datafile, which has comma-separated labels on
|
||||
* each line. Line i corresponds to point id i.
|
||||
*
|
||||
* Returns three objects via std::tuple:
|
||||
* 1. map: key is point id, value is vector of labels said point has
|
||||
* 2. map: key is label, value is number of points with the label
|
||||
* 3. the label universe as a set
|
||||
*/
|
||||
parse_label_file_return_values parse_label_file(path label_data_path,
|
||||
std::string universal_label) {
|
||||
std::ifstream label_data_stream(label_data_path);
|
||||
std::string line, token;
|
||||
unsigned line_cnt = 0;
|
||||
|
||||
// allows us to reserve space for the points_to_labels vector
|
||||
while (std::getline(label_data_stream, line))
|
||||
line_cnt++;
|
||||
label_data_stream.clear();
|
||||
label_data_stream.seekg(0, std::ios::beg);
|
||||
|
||||
// values to return
|
||||
std::vector<label_set> point_ids_to_labels(line_cnt);
|
||||
tsl::robin_map<std::string, _u32> labels_to_number_of_points;
|
||||
label_set all_labels;
|
||||
|
||||
std::vector<_u32> points_with_universal_label;
|
||||
line_cnt = 0;
|
||||
while (std::getline(label_data_stream, line)) {
|
||||
std::istringstream current_labels_comma_separated(line);
|
||||
label_set current_labels;
|
||||
|
||||
// get point id
|
||||
_u32 point_id = line_cnt;
|
||||
|
||||
// parse comma separated labels
|
||||
bool current_universal_label_check = false;
|
||||
while (getline(current_labels_comma_separated, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
|
||||
// if token is empty, there's no labels for the point
|
||||
if (token == universal_label) {
|
||||
points_with_universal_label.push_back(point_id);
|
||||
current_universal_label_check = true;
|
||||
} else {
|
||||
all_labels.insert(token);
|
||||
current_labels.insert(token);
|
||||
labels_to_number_of_points[token]++;
|
||||
}
|
||||
}
|
||||
|
||||
if (current_labels.size() <= 0 && !current_universal_label_check) {
|
||||
std::cerr << "Error: " << point_id << " has no labels." << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
point_ids_to_labels[point_id] = current_labels;
|
||||
line_cnt++;
|
||||
}
|
||||
|
||||
// for every point with universal label, set its label set to all labels
|
||||
// also, increment the count for number of points a label has
|
||||
for (const auto &point_id : points_with_universal_label) {
|
||||
point_ids_to_labels[point_id] = all_labels;
|
||||
for (const auto &lbl : all_labels)
|
||||
labels_to_number_of_points[lbl]++;
|
||||
}
|
||||
|
||||
std::cout << "Identified " << all_labels.size() << " distinct label(s) for "
|
||||
<< point_ids_to_labels.size() << " points\n"
|
||||
<< std::endl;
|
||||
|
||||
return std::make_tuple(point_ids_to_labels, labels_to_number_of_points,
|
||||
all_labels);
|
||||
}
|
||||
|
||||
/*
|
||||
* For each label, generates a file containing all vectors that have said label.
|
||||
* Also copies data from original bin file to new dimension-aligned file.
|
||||
*
|
||||
* Utilizes POSIX functions mmap and writev in order to minimize memory
|
||||
* overhead, so we include an STL version as well.
|
||||
*
|
||||
* Each data file is saved under the following format:
|
||||
* input_data_path + "_" + label
|
||||
*/
|
||||
template<typename T>
|
||||
tsl::robin_map<std::string, std::vector<_u32>>
|
||||
generate_label_specific_vector_files(
|
||||
path input_data_path,
|
||||
tsl::robin_map<std::string, _u32> labels_to_number_of_points,
|
||||
std::vector<label_set> point_ids_to_labels, label_set all_labels) {
|
||||
auto file_writing_timer = std::chrono::high_resolution_clock::now();
|
||||
diskann::MemoryMapper input_data(input_data_path);
|
||||
char *input_start = input_data.getBuf();
|
||||
|
||||
_u32 number_of_points, dimension;
|
||||
std::memcpy(&number_of_points, input_start, sizeof(_u32));
|
||||
std::memcpy(&dimension, input_start + sizeof(_u32), sizeof(_u32));
|
||||
const _u32 VECTOR_SIZE = dimension * sizeof(T);
|
||||
const size_t METADATA = 2 * sizeof(_u32);
|
||||
if (number_of_points != point_ids_to_labels.size()) {
|
||||
std::cerr << "Error: number of points in labels file and data file differ."
|
||||
<< std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
tsl::robin_map<std::string, iovec *> label_to_iovec_map;
|
||||
tsl::robin_map<std::string, _u32> label_to_curr_iovec;
|
||||
tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id;
|
||||
|
||||
// setup iovec list for each label
|
||||
for (const auto &lbl : all_labels) {
|
||||
iovec *label_iovecs =
|
||||
(iovec *) malloc(labels_to_number_of_points[lbl] * sizeof(iovec));
|
||||
if (label_iovecs == nullptr) {
|
||||
throw;
|
||||
}
|
||||
label_to_iovec_map[lbl] = label_iovecs;
|
||||
label_to_curr_iovec[lbl] = 0;
|
||||
label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]);
|
||||
}
|
||||
|
||||
// each point added to corresponding per-label iovec list
|
||||
for (_u32 point_id = 0; point_id < number_of_points; point_id++) {
|
||||
char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id);
|
||||
iovec curr_iovec;
|
||||
|
||||
curr_iovec.iov_base = curr_point;
|
||||
curr_iovec.iov_len = VECTOR_SIZE;
|
||||
for (const auto &lbl : point_ids_to_labels[point_id]) {
|
||||
*(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec;
|
||||
label_to_curr_iovec[lbl]++;
|
||||
label_id_to_orig_id[lbl].push_back(point_id);
|
||||
}
|
||||
}
|
||||
|
||||
// write each label iovec to resp. file
|
||||
for (const auto &lbl : all_labels) {
|
||||
int label_input_data_fd;
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
_u32 curr_num_pts = labels_to_number_of_points[lbl];
|
||||
|
||||
label_input_data_fd =
|
||||
open(curr_label_input_data_path.c_str(),
|
||||
O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t) 0644);
|
||||
if (label_input_data_fd == -1)
|
||||
throw;
|
||||
|
||||
// write metadata
|
||||
_u32 metadata[2] = {curr_num_pts, dimension};
|
||||
int return_value = write(label_input_data_fd, metadata, sizeof(_u32) * 2);
|
||||
if (return_value == -1) {
|
||||
throw;
|
||||
}
|
||||
|
||||
// limits on number of iovec structs per writev means we need to perform
|
||||
// multiple writevs
|
||||
size_t i = 0;
|
||||
while (curr_num_pts > IOV_MAX) {
|
||||
return_value = writev(label_input_data_fd,
|
||||
(label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX);
|
||||
if (return_value == -1) {
|
||||
close(label_input_data_fd);
|
||||
throw;
|
||||
}
|
||||
curr_num_pts -= IOV_MAX;
|
||||
i += 1;
|
||||
}
|
||||
return_value =
|
||||
writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)),
|
||||
curr_num_pts);
|
||||
if (return_value == -1) {
|
||||
close(label_input_data_fd);
|
||||
throw;
|
||||
}
|
||||
|
||||
free(label_to_iovec_map[lbl]);
|
||||
close(label_input_data_fd);
|
||||
}
|
||||
|
||||
std::chrono::duration<double> file_writing_time =
|
||||
std::chrono::high_resolution_clock::now() - file_writing_timer;
|
||||
std::cout << "generated " << all_labels.size()
|
||||
<< " label-specific vector files for index building in time "
|
||||
<< file_writing_time.count() << "\n"
|
||||
<< std::endl;
|
||||
|
||||
return label_id_to_orig_id;
|
||||
}
|
||||
|
||||
// for use on systems without writev (i.e. Windows)
|
||||
template<typename T>
|
||||
tsl::robin_map<std::string, std::vector<_u32>>
|
||||
generate_label_specific_vector_files_compat(
|
||||
path input_data_path,
|
||||
tsl::robin_map<std::string, _u32> labels_to_number_of_points,
|
||||
std::vector<label_set> point_ids_to_labels, label_set all_labels) {
|
||||
auto file_writing_timer = std::chrono::high_resolution_clock::now();
|
||||
std::ifstream input_data_stream(input_data_path);
|
||||
|
||||
_u32 number_of_points, dimension;
|
||||
input_data_stream.read((char *) &number_of_points, sizeof(_u32));
|
||||
input_data_stream.read((char *) &dimension, sizeof(_u32));
|
||||
const _u32 VECTOR_SIZE = dimension * sizeof(T);
|
||||
if (number_of_points != point_ids_to_labels.size()) {
|
||||
std::cerr << "Error: number of points in labels file and data file differ."
|
||||
<< std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
tsl::robin_map<std::string, char *> labels_to_vectors;
|
||||
tsl::robin_map<std::string, _u32> labels_to_curr_vector;
|
||||
tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id;
|
||||
|
||||
for (const auto &lbl : all_labels) {
|
||||
_u32 number_of_label_pts = labels_to_number_of_points[lbl];
|
||||
char *vectors = (char *) malloc(number_of_label_pts * VECTOR_SIZE);
|
||||
if (vectors == nullptr) {
|
||||
throw;
|
||||
}
|
||||
labels_to_vectors[lbl] = vectors;
|
||||
labels_to_curr_vector[lbl] = 0;
|
||||
label_id_to_orig_id[lbl].reserve(number_of_label_pts);
|
||||
}
|
||||
|
||||
for (_u32 point_id = 0; point_id < number_of_points; point_id++) {
|
||||
char *curr_vector = (char *) malloc(VECTOR_SIZE);
|
||||
input_data_stream.read(curr_vector, VECTOR_SIZE);
|
||||
for (const auto &lbl : point_ids_to_labels[point_id]) {
|
||||
char *curr_label_vector_ptr =
|
||||
labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE);
|
||||
memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE);
|
||||
labels_to_curr_vector[lbl]++;
|
||||
label_id_to_orig_id[lbl].push_back(point_id);
|
||||
}
|
||||
free(curr_vector);
|
||||
}
|
||||
|
||||
for (const auto &lbl : all_labels) {
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
_u32 number_of_label_pts = labels_to_number_of_points[lbl];
|
||||
|
||||
std::ofstream label_file_stream;
|
||||
label_file_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
label_file_stream.open(curr_label_input_data_path, std::ios_base::binary);
|
||||
label_file_stream.write((char *) &number_of_label_pts, sizeof(_u32));
|
||||
label_file_stream.write((char *) &dimension, sizeof(_u32));
|
||||
label_file_stream.write((char *) labels_to_vectors[lbl],
|
||||
number_of_label_pts * VECTOR_SIZE);
|
||||
|
||||
label_file_stream.close();
|
||||
free(labels_to_vectors[lbl]);
|
||||
}
|
||||
input_data_stream.close();
|
||||
|
||||
std::chrono::duration<double> file_writing_time =
|
||||
std::chrono::high_resolution_clock::now() - file_writing_timer;
|
||||
std::cout << "generated " << all_labels.size()
|
||||
<< " label-specific vector files for index building in time "
|
||||
<< file_writing_time.count() << "\n"
|
||||
<< std::endl;
|
||||
|
||||
return label_id_to_orig_id;
|
||||
}
|
||||
|
||||
/*
|
||||
* Using passed in parameters and files generated from step 3,
|
||||
* builds a vanilla diskANN index for each label.
|
||||
*
|
||||
* Each index is saved under the following path:
|
||||
* final_index_path_prefix + "_" + label
|
||||
*/
|
||||
template<typename T>
|
||||
void generate_label_indices(path input_data_path, path final_index_path_prefix,
|
||||
label_set all_labels, unsigned R, unsigned L,
|
||||
float alpha, unsigned num_threads) {
|
||||
diskann::Parameters label_index_build_parameters;
|
||||
label_index_build_parameters.Set<unsigned>("R", R);
|
||||
label_index_build_parameters.Set<unsigned>("L", L);
|
||||
label_index_build_parameters.Set<unsigned>("C", 750);
|
||||
label_index_build_parameters.Set<unsigned>("Lf", 0);
|
||||
label_index_build_parameters.Set<bool>("saturate_graph", 0);
|
||||
label_index_build_parameters.Set<float>("alpha", alpha);
|
||||
label_index_build_parameters.Set<unsigned>("num_threads", num_threads);
|
||||
|
||||
std::cout << "Generating indices per label..." << std::endl;
|
||||
// for each label, build an index on resp. points
|
||||
double total_indexing_time = 0.0, indexing_percentage = 0.0;
|
||||
std::cout.setstate(std::ios_base::failbit);
|
||||
diskann::cout.setstate(std::ios_base::failbit);
|
||||
for (const auto &lbl : all_labels) {
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
|
||||
size_t number_of_label_points, dimension;
|
||||
diskann::get_bin_metadata(curr_label_input_data_path,
|
||||
number_of_label_points, dimension);
|
||||
diskann::Index<T> index(diskann::Metric::L2, dimension,
|
||||
number_of_label_points, false, false);
|
||||
|
||||
auto index_build_timer = std::chrono::high_resolution_clock::now();
|
||||
index.build(curr_label_input_data_path.c_str(), number_of_label_points,
|
||||
label_index_build_parameters);
|
||||
std::chrono::duration<double> current_indexing_time =
|
||||
std::chrono::high_resolution_clock::now() - index_build_timer;
|
||||
|
||||
total_indexing_time += current_indexing_time.count();
|
||||
indexing_percentage += (1 / (double) all_labels.size());
|
||||
print_progress(indexing_percentage);
|
||||
|
||||
index.save(curr_label_index_path.c_str());
|
||||
}
|
||||
std::cout.clear();
|
||||
diskann::cout.clear();
|
||||
|
||||
std::cout << "\nDone. Generated per-label indices in " << total_indexing_time
|
||||
<< " seconds\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Manually loads a graph index in from a given file.
|
||||
*
|
||||
* Returns both the graph index and the size of the file in bytes.
|
||||
*/
|
||||
load_label_index_return_values load_label_index(path label_index_path,
|
||||
_u32 label_number_of_points) {
|
||||
std::ifstream label_index_stream;
|
||||
label_index_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
label_index_stream.open(label_index_path, std::ios::binary);
|
||||
|
||||
_u64 index_file_size, index_num_frozen_points;
|
||||
_u32 index_max_observed_degree, index_entry_point;
|
||||
const size_t INDEX_METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
|
||||
label_index_stream.read((char *) &index_file_size, sizeof(_u64));
|
||||
label_index_stream.read((char *) &index_max_observed_degree, sizeof(_u32));
|
||||
label_index_stream.read((char *) &index_entry_point, sizeof(_u32));
|
||||
label_index_stream.read((char *) &index_num_frozen_points, sizeof(_u64));
|
||||
size_t bytes_read = INDEX_METADATA;
|
||||
|
||||
std::vector<std::vector<_u32>> label_index(label_number_of_points);
|
||||
_u32 nodes_read = 0;
|
||||
while (bytes_read != index_file_size) {
|
||||
_u32 current_node_num_neighbors;
|
||||
label_index_stream.read((char *) ¤t_node_num_neighbors, sizeof(_u32));
|
||||
nodes_read++;
|
||||
|
||||
std::vector<_u32> current_node_neighbors(current_node_num_neighbors);
|
||||
label_index_stream.read((char *) current_node_neighbors.data(),
|
||||
current_node_num_neighbors * sizeof(_u32));
|
||||
label_index[nodes_read - 1].swap(current_node_neighbors);
|
||||
bytes_read += sizeof(_u32) * (current_node_num_neighbors + 1);
|
||||
}
|
||||
|
||||
return std::make_tuple(label_index, index_file_size);
|
||||
}
|
||||
|
||||
/*
|
||||
* Custom index save to write the in-memory index to disk.
|
||||
* Also writes required files for diskANN API -
|
||||
* 1. labels_to_medoids
|
||||
* 2. universal_label
|
||||
* 3. data (redundant for static indices)
|
||||
* 4. labels (redundant for static indices)
|
||||
*/
|
||||
void save_full_index(path final_index_path_prefix, path input_data_path,
|
||||
_u64 final_index_size,
|
||||
std::vector<std::vector<_u32>> stitched_graph,
|
||||
tsl::robin_map<std::string, _u32> entry_points,
|
||||
std::string universal_label, path label_data_path) {
|
||||
// aux. file 1
|
||||
auto saving_index_timer = std::chrono::high_resolution_clock::now();
|
||||
std::ifstream original_label_data_stream;
|
||||
original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
original_label_data_stream.open(label_data_path, std::ios::binary);
|
||||
std::ofstream new_label_data_stream;
|
||||
new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
new_label_data_stream.open(final_index_path_prefix + "_labels.txt",
|
||||
std::ios::binary);
|
||||
new_label_data_stream << original_label_data_stream.rdbuf();
|
||||
original_label_data_stream.close();
|
||||
new_label_data_stream.close();
|
||||
|
||||
// aux. file 2
|
||||
std::ifstream original_input_data_stream;
|
||||
original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
original_input_data_stream.open(input_data_path, std::ios::binary);
|
||||
std::ofstream new_input_data_stream;
|
||||
new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
new_input_data_stream.open(final_index_path_prefix + ".data",
|
||||
std::ios::binary);
|
||||
new_input_data_stream << original_input_data_stream.rdbuf();
|
||||
original_input_data_stream.close();
|
||||
new_input_data_stream.close();
|
||||
|
||||
// aux. file 3
|
||||
std::ofstream labels_to_medoids_writer;
|
||||
labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
labels_to_medoids_writer.open(final_index_path_prefix +
|
||||
"_labels_to_medoids.txt");
|
||||
for (auto iter : entry_points)
|
||||
labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl;
|
||||
labels_to_medoids_writer.close();
|
||||
|
||||
// aux. file 4 (only if we're using a universal label)
|
||||
if (universal_label != "") {
|
||||
std::ofstream universal_label_writer;
|
||||
universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
universal_label_writer.open(final_index_path_prefix +
|
||||
"_universal_label.txt");
|
||||
universal_label_writer << universal_label << std::endl;
|
||||
universal_label_writer.close();
|
||||
}
|
||||
|
||||
// main index
|
||||
_u64 index_num_frozen_points = 0, index_num_edges = 0;
|
||||
_u32 index_max_observed_degree = 0, index_entry_point = 0;
|
||||
const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
|
||||
for (auto &point_neighbors : stitched_graph) {
|
||||
index_max_observed_degree =
|
||||
std::max(index_max_observed_degree, (_u32) point_neighbors.size());
|
||||
}
|
||||
|
||||
std::ofstream stitched_graph_writer;
|
||||
stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit);
|
||||
stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary);
|
||||
|
||||
stitched_graph_writer.write((char *) &final_index_size, sizeof(_u64));
|
||||
stitched_graph_writer.write((char *) &index_max_observed_degree,
|
||||
sizeof(_u32));
|
||||
stitched_graph_writer.write((char *) &index_entry_point, sizeof(_u32));
|
||||
stitched_graph_writer.write((char *) &index_num_frozen_points, sizeof(_u64));
|
||||
|
||||
size_t bytes_written = METADATA;
|
||||
for (_u32 node_point = 0; node_point < stitched_graph.size(); node_point++) {
|
||||
_u32 current_node_num_neighbors = stitched_graph[node_point].size();
|
||||
std::vector<_u32> current_node_neighbors = stitched_graph[node_point];
|
||||
stitched_graph_writer.write((char *) ¤t_node_num_neighbors,
|
||||
sizeof(_u32));
|
||||
bytes_written += sizeof(_u32);
|
||||
for (const auto ¤t_node_neighbor : current_node_neighbors) {
|
||||
stitched_graph_writer.write((char *) ¤t_node_neighbor,
|
||||
sizeof(_u32));
|
||||
bytes_written += sizeof(_u32);
|
||||
}
|
||||
index_num_edges += current_node_num_neighbors;
|
||||
}
|
||||
|
||||
if (bytes_written != final_index_size) {
|
||||
std::cerr << "Error: written bytes does not match allocated space"
|
||||
<< std::endl;
|
||||
throw;
|
||||
}
|
||||
|
||||
stitched_graph_writer.close();
|
||||
|
||||
std::chrono::duration<double> saving_index_time =
|
||||
std::chrono::high_resolution_clock::now() - saving_index_timer;
|
||||
std::cout << "Stitched graph written in " << saving_index_time.count()
|
||||
<< " seconds" << std::endl;
|
||||
std::cout << "Stitched graph average degree: "
|
||||
<< ((float) index_num_edges) / ((float) (stitched_graph.size()))
|
||||
<< std::endl;
|
||||
std::cout << "Stitched graph max degree: " << index_max_observed_degree
|
||||
<< std::endl
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Unions the per-label graph indices together via the following policy:
|
||||
* - any two nodes can only have at most one edge between them -
|
||||
*
|
||||
* Returns the "stitched" graph and its expected file size.
|
||||
*/
|
||||
template<typename T>
|
||||
stitch_indices_return_values stitch_label_indices(
|
||||
path final_index_path_prefix, _u32 total_number_of_points,
|
||||
label_set all_labels,
|
||||
tsl::robin_map<std::string, _u32> labels_to_number_of_points,
|
||||
tsl::robin_map<std::string, _u32> &label_entry_points,
|
||||
tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map) {
|
||||
size_t final_index_size = 0;
|
||||
std::vector<std::vector<_u32>> stitched_graph(total_number_of_points);
|
||||
|
||||
auto stitching_index_timer = std::chrono::high_resolution_clock::now();
|
||||
for (const auto &lbl : all_labels) {
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
std::vector<std::vector<_u32>> curr_label_index;
|
||||
_u64 curr_label_index_size;
|
||||
_u32 curr_label_entry_point;
|
||||
|
||||
std::tie(curr_label_index, curr_label_index_size) = load_label_index(
|
||||
curr_label_index_path, labels_to_number_of_points[lbl]);
|
||||
curr_label_entry_point = random(0, curr_label_index.size());
|
||||
label_entry_points[lbl] =
|
||||
label_id_to_orig_id_map[lbl][curr_label_entry_point];
|
||||
|
||||
for (_u32 node_point = 0; node_point < curr_label_index.size();
|
||||
node_point++) {
|
||||
_u32 original_point_id = label_id_to_orig_id_map[lbl][node_point];
|
||||
for (auto &node_neighbor : curr_label_index[node_point]) {
|
||||
_u32 original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor];
|
||||
std::vector<_u32> curr_point_neighbors =
|
||||
stitched_graph[original_point_id];
|
||||
if (std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(),
|
||||
original_neighbor_id) == curr_point_neighbors.end()) {
|
||||
stitched_graph[original_point_id].push_back(original_neighbor_id);
|
||||
final_index_size += sizeof(_u32);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32);
|
||||
final_index_size += (total_number_of_points * sizeof(_u32) + METADATA);
|
||||
|
||||
std::chrono::duration<double> stitching_index_time =
|
||||
std::chrono::high_resolution_clock::now() - stitching_index_timer;
|
||||
std::cout << "stitched graph generated in memory in "
|
||||
<< stitching_index_time.count() << " seconds" << std::endl;
|
||||
|
||||
return std::make_tuple(stitched_graph, final_index_size);
|
||||
}
|
||||
|
||||
/*
|
||||
* Applies the prune_neighbors function from src/index.cpp to
|
||||
* every node in the stitched graph.
|
||||
*
|
||||
* This is an optional step, hence the saving of both the full
|
||||
* and pruned graph.
|
||||
*/
|
||||
template<typename T>
|
||||
void prune_and_save(path final_index_path_prefix, path full_index_path_prefix,
|
||||
path input_data_path,
|
||||
std::vector<std::vector<_u32>> stitched_graph,
|
||||
unsigned stitched_R,
|
||||
tsl::robin_map<std::string, _u32> label_entry_points,
|
||||
std::string universal_label, path label_data_path,
|
||||
unsigned num_threads) {
|
||||
size_t dimension, number_of_label_points;
|
||||
auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr);
|
||||
auto std_cout_buffer = std::cout.rdbuf(nullptr);
|
||||
auto pruning_index_timer = std::chrono::high_resolution_clock::now();
|
||||
|
||||
diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension);
|
||||
diskann::Index<T> index(diskann::Metric::L2, dimension,
|
||||
number_of_label_points, false, false);
|
||||
|
||||
// not searching this index, set search_l to 0
|
||||
index.load(full_index_path_prefix.c_str(), num_threads, 1);
|
||||
|
||||
diskann::Parameters paras;
|
||||
paras.Set<unsigned>("R", stitched_R);
|
||||
paras.Set<unsigned>(
|
||||
"C", 750); // maximum candidate set size during pruning procedure
|
||||
paras.Set<float>("alpha", 1.2);
|
||||
paras.Set<bool>("saturate_graph", 1);
|
||||
std::cout << "parsing labels" << std::endl;
|
||||
|
||||
index.prune_all_nbrs(paras);
|
||||
index.save((final_index_path_prefix).c_str());
|
||||
|
||||
diskann::cout.rdbuf(diskann_cout_buffer);
|
||||
std::cout.rdbuf(std_cout_buffer);
|
||||
std::chrono::duration<double> pruning_index_time =
|
||||
std::chrono::high_resolution_clock::now() - pruning_index_timer;
|
||||
std::cout << "pruning performed in " << pruning_index_time.count()
|
||||
<< " seconds\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
/*
|
||||
* Delete all temporary artifacts.
|
||||
* In the process of creating the stitched index, some temporary artifacts are
|
||||
* created:
|
||||
* 1. the separate bin files for each labels' points
|
||||
* 2. the separate diskANN indices built for each label
|
||||
* 3. the '.data' file created while generating the indices
|
||||
*/
|
||||
void clean_up_artifacts(path input_data_path, path final_index_path_prefix,
|
||||
label_set all_labels) {
|
||||
for (const auto &lbl : all_labels) {
|
||||
path curr_label_input_data_path(input_data_path + "_" + lbl);
|
||||
path curr_label_index_path(final_index_path_prefix + "_" + lbl);
|
||||
path curr_label_index_path_data(curr_label_index_path + ".data");
|
||||
|
||||
if (std::remove(curr_label_index_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_input_data_path.c_str()) != 0)
|
||||
throw;
|
||||
if (std::remove(curr_label_index_path_data.c_str()) != 0)
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// 1. handle cmdline inputs
|
||||
std::string data_type;
|
||||
path input_data_path, final_index_path_prefix, label_data_path;
|
||||
std::string universal_label;
|
||||
unsigned num_threads, R, L, stitched_R;
|
||||
float alpha;
|
||||
|
||||
auto index_timer = std::chrono::high_resolution_clock::now();
|
||||
handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix,
|
||||
label_data_path, universal_label, num_threads, R, L, stitched_R,
|
||||
alpha);
|
||||
|
||||
path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt";
|
||||
path labels_map_file = final_index_path_prefix + "_labels_map.txt";
|
||||
|
||||
convert_labels_string_to_int(label_data_path, labels_file_to_use,
|
||||
labels_map_file, universal_label);
|
||||
|
||||
// 2. parse label file and create necessary data structures
|
||||
std::vector<label_set> point_ids_to_labels;
|
||||
tsl::robin_map<std::string, _u32> labels_to_number_of_points;
|
||||
label_set all_labels;
|
||||
|
||||
std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) =
|
||||
parse_label_file(labels_file_to_use, universal_label);
|
||||
|
||||
// 3. for each label, make a separate data file
|
||||
tsl::robin_map<std::string, std::vector<_u32>> label_id_to_orig_id_map;
|
||||
_u32 total_number_of_points = point_ids_to_labels.size();
|
||||
|
||||
#ifndef _WINDOWS
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map = generate_label_specific_vector_files<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map = generate_label_specific_vector_files<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map = generate_label_specific_vector_files<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else
|
||||
throw;
|
||||
#else
|
||||
if (data_type == "uint8")
|
||||
label_id_to_orig_id_map =
|
||||
generate_label_specific_vector_files_compat<uint8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else if (data_type == "int8")
|
||||
label_id_to_orig_id_map =
|
||||
generate_label_specific_vector_files_compat<int8_t>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else if (data_type == "float")
|
||||
label_id_to_orig_id_map =
|
||||
generate_label_specific_vector_files_compat<float>(
|
||||
input_data_path, labels_to_number_of_points, point_ids_to_labels,
|
||||
all_labels);
|
||||
else
|
||||
throw;
|
||||
#endif
|
||||
|
||||
// 4. for each created data file, create a vanilla diskANN index
|
||||
if (data_type == "uint8")
|
||||
generate_label_indices<uint8_t>(input_data_path, final_index_path_prefix,
|
||||
all_labels, R, L, alpha, num_threads);
|
||||
else if (data_type == "int8")
|
||||
generate_label_indices<int8_t>(input_data_path, final_index_path_prefix,
|
||||
all_labels, R, L, alpha, num_threads);
|
||||
else if (data_type == "float")
|
||||
generate_label_indices<float>(input_data_path, final_index_path_prefix,
|
||||
all_labels, R, L, alpha, num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
// 5. "stitch" the indices together
|
||||
std::vector<std::vector<_u32>> stitched_graph;
|
||||
tsl::robin_map<std::string, _u32> label_entry_points;
|
||||
_u64 stitched_graph_size;
|
||||
|
||||
if (data_type == "uint8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<uint8_t>(
|
||||
final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points,
|
||||
label_id_to_orig_id_map);
|
||||
else if (data_type == "int8")
|
||||
std::tie(stitched_graph, stitched_graph_size) =
|
||||
stitch_label_indices<int8_t>(
|
||||
final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points,
|
||||
label_id_to_orig_id_map);
|
||||
else if (data_type == "float")
|
||||
std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices<float>(
|
||||
final_index_path_prefix, total_number_of_points, all_labels,
|
||||
labels_to_number_of_points, label_entry_points,
|
||||
label_id_to_orig_id_map);
|
||||
else
|
||||
throw;
|
||||
path full_index_path_prefix = final_index_path_prefix + "_full";
|
||||
// 5a. save the stitched graph to disk
|
||||
save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size,
|
||||
stitched_graph, label_entry_points, universal_label,
|
||||
labels_file_to_use);
|
||||
|
||||
// 6. run a prune on the stitched index, and save to disk
|
||||
if (data_type == "uint8")
|
||||
prune_and_save<uint8_t>(final_index_path_prefix, full_index_path_prefix,
|
||||
input_data_path, stitched_graph, stitched_R,
|
||||
label_entry_points, universal_label,
|
||||
labels_file_to_use, num_threads);
|
||||
else if (data_type == "int8")
|
||||
prune_and_save<int8_t>(final_index_path_prefix, full_index_path_prefix,
|
||||
input_data_path, stitched_graph, stitched_R,
|
||||
label_entry_points, universal_label,
|
||||
labels_file_to_use, num_threads);
|
||||
else if (data_type == "float")
|
||||
prune_and_save<float>(final_index_path_prefix, full_index_path_prefix,
|
||||
input_data_path, stitched_graph, stitched_R,
|
||||
label_entry_points, universal_label,
|
||||
labels_file_to_use, num_threads);
|
||||
else
|
||||
throw;
|
||||
|
||||
std::chrono::duration<double> index_time =
|
||||
std::chrono::high_resolution_clock::now() - index_timer;
|
||||
std::cout << "pruned/stitched graph generated in " << index_time.count()
|
||||
<< " seconds" << std::endl;
|
||||
|
||||
clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels);
|
||||
}
|
|
@ -47,7 +47,7 @@ void print_stats(std::string category, std::vector<float> percentiles,
|
|||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(diskann::Metric& metric,
|
||||
const std::string& index_path_prefix,
|
||||
const std::string& query_file, std::string& gt_file,
|
||||
|
@ -99,8 +99,8 @@ int search_disk_index(diskann::Metric& metric,
|
|||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T>(reader, metric));
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
|
@ -202,8 +202,8 @@ int search_disk_index(diskann::Metric& metric,
|
|||
std::vector<_u64> indices;
|
||||
std::vector<float> distances;
|
||||
_u32 res_count = _pFlashIndex->range_search(
|
||||
query + (i * query_aligned_dim), search_range, L, max_list_size,
|
||||
indices, distances, optimized_beamwidth, stats + i);
|
||||
query + (i * query_aligned_dim), search_range, L, max_list_size,
|
||||
indices, distances, optimized_beamwidth, stats + i);
|
||||
query_result_ids[test_id][i].reserve(res_count);
|
||||
query_result_ids[test_id][i].resize(res_count);
|
||||
for (_u32 idx = 0; idx < res_count; idx++)
|
||||
|
|
|
@ -12,7 +12,6 @@
|
|||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
|
||||
#include <cpprest/http_client.h>
|
||||
#include <restapi/common.h>
|
||||
|
||||
|
@ -65,47 +64,40 @@ void query_loop(const std::string& ip_addr_port, const std::string& query_file,
|
|||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string data_type, query_file, address;
|
||||
uint32_t num_queries;
|
||||
uint32_t l_search, k_value;
|
||||
std::string data_type, query_file, address;
|
||||
uint32_t num_queries;
|
||||
uint32_t l_search, k_value;
|
||||
|
||||
po::options_description desc{ "Arguments" };
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address",
|
||||
po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("query_file",
|
||||
po::value<std::string>(&query_file)->required(),
|
||||
"File containing the queries to search");
|
||||
desc.add_options()(
|
||||
"num_queries,Q",
|
||||
po::value<uint32_t>(&num_queries)->required(),
|
||||
"Number of queries to search");
|
||||
desc.add_options()(
|
||||
"l_search",
|
||||
po::value<uint32_t>(&l_search)->required(),
|
||||
"Value of L");
|
||||
desc.add_options()(
|
||||
"k_value,K",
|
||||
po::value<uint32_t>(&k_value)->default_value(10),
|
||||
"Value of K (default 10)");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("query_file",
|
||||
po::value<std::string>(&query_file)->required(),
|
||||
"File containing the queries to search");
|
||||
desc.add_options()("num_queries,Q",
|
||||
po::value<uint32_t>(&num_queries)->required(),
|
||||
"Number of queries to search");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(),
|
||||
"Value of L");
|
||||
desc.add_options()("k_value,K",
|
||||
po::value<uint32_t>(&k_value)->default_value(10),
|
||||
"Value of K (default 10)");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float")) {
|
||||
query_loop<float>(address, query_file, num_queries, l_search, k_value);
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include <codecvt>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
|
@ -37,99 +36,96 @@ void teardown(const utility::string_t& address) {
|
|||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
|
||||
uint32_t num_threads;
|
||||
uint32_t l_search;
|
||||
std::string data_type, index_file, data_file, address, dist_fn, tags_file;
|
||||
uint32_t num_threads;
|
||||
uint32_t l_search;
|
||||
|
||||
po::options_description desc{ "Arguments" };
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address",
|
||||
po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("data_file",
|
||||
po::value<std::string>(&data_file)->required(),
|
||||
"File containing the data found in the index");
|
||||
desc.add_options()("index_path_prefix",
|
||||
po::value<std::string>(&index_file)->required(),
|
||||
"Path prefix for saving index file components");
|
||||
desc.add_options()(
|
||||
"num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->required(),
|
||||
"Number of threads used for building index");
|
||||
desc.add_options()(
|
||||
"l_search",
|
||||
po::value<uint32_t>(&l_search)->required(),
|
||||
"Value of L");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file",
|
||||
po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else {
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("data_file",
|
||||
po::value<std::string>(&data_file)->required(),
|
||||
"File containing the data found in the index");
|
||||
desc.add_options()("index_path_prefix",
|
||||
po::value<std::string>(&index_file)->required(),
|
||||
"Path prefix for saving index file components");
|
||||
desc.add_options()("num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->required(),
|
||||
"Number of threads used for building index");
|
||||
desc.add_options()("l_search", po::value<uint32_t>(&l_search)->required(),
|
||||
"Value of L");
|
||||
desc.add_options()("dist_fn",
|
||||
po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()(
|
||||
"tags_file",
|
||||
po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else {
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float")) {
|
||||
auto searcher =
|
||||
if (data_type == std::string("float")) {
|
||||
auto searcher =
|
||||
std::unique_ptr<diskann::BaseSearch>(new diskann::InMemorySearch<float>(
|
||||
data_file, index_file, tags_file, metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("int8")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file,
|
||||
metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file,
|
||||
metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else {
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
}
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("int8")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<int8_t>(data_file, index_file, tags_file,
|
||||
metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::InMemorySearch<uint8_t>(data_file, index_file, tags_file,
|
||||
metric, num_threads, l_search));
|
||||
g_inMemorySearch.push_back(std::move(searcher));
|
||||
} else {
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
}
|
||||
|
||||
while (1) {
|
||||
try {
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit") {
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
while (1) {
|
||||
try {
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit") {
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,134 +37,135 @@ void teardown(const utility::string_t& address) {
|
|||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
std::string data_type, index_prefix_paths, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{ "Arguments" };
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("address",
|
||||
po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("index_prefix_paths",
|
||||
po::value<std::string>(&index_prefix_paths)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()(
|
||||
"num_nodes_to_cache",
|
||||
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()(
|
||||
"num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file",
|
||||
po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("index_prefix_paths",
|
||||
po::value<std::string>(&index_prefix_paths)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()(
|
||||
"num_nodes_to_cache",
|
||||
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
"Number of nodes to cache during search");
|
||||
desc.add_options()(
|
||||
"num_threads,T",
|
||||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn",
|
||||
po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()(
|
||||
"tags_file",
|
||||
po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
}
|
||||
catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else {
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::vector<std::pair<std::string, std::string>> index_tag_paths;
|
||||
std::ifstream index_in(index_prefix_paths);
|
||||
if (!index_in.is_open()) {
|
||||
std::cerr << "Could not open " << index_prefix_paths << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream tags_in(tags_file);
|
||||
if (!tags_in.is_open()) {
|
||||
std::cerr << "Could not open " << tags_file << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string prefix, tagfile;
|
||||
while (std::getline(index_in, prefix)) {
|
||||
if (std::getline(tags_in, tagfile)) {
|
||||
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
|
||||
} else {
|
||||
std::cerr << "The number of tags specified does not match the number of "
|
||||
"indices specified" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
index_in.close();
|
||||
tags_in.close();
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else {
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (data_type == std::string("float")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<float>(index_tag.first.c_str(),
|
||||
num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
|
||||
} else if (data_type == std::string("int8")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<int8_t>(index_tag.first.c_str(),
|
||||
num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<uint8_t>(index_tag.first.c_str(),
|
||||
num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
std::vector<std::pair<std::string, std::string>> index_tag_paths;
|
||||
std::ifstream index_in(index_prefix_paths);
|
||||
if (!index_in.is_open()) {
|
||||
std::cerr << "Could not open " << index_prefix_paths << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::ifstream tags_in(tags_file);
|
||||
if (!tags_in.is_open()) {
|
||||
std::cerr << "Could not open " << tags_file << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
std::string prefix, tagfile;
|
||||
while (std::getline(index_in, prefix)) {
|
||||
if (std::getline(tags_in, tagfile)) {
|
||||
index_tag_paths.push_back(std::make_pair(prefix, tagfile));
|
||||
} else {
|
||||
std::cerr << "Unsupported data type " << data_type << std::endl;
|
||||
exit(-1);
|
||||
std::cerr << "The number of tags specified does not match the number of "
|
||||
"indices specified"
|
||||
<< std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
index_in.close();
|
||||
tags_in.close();
|
||||
|
||||
if (data_type == std::string("float")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<float>(index_tag.first.c_str(),
|
||||
num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
|
||||
while (1) {
|
||||
try {
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit") {
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
} else if (data_type == std::string("int8")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<int8_t>(index_tag.first.c_str(),
|
||||
num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
for (auto& index_tag : index_tag_paths) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<uint8_t>(
|
||||
index_tag.first.c_str(), num_nodes_to_cache, num_threads,
|
||||
index_tag.second.c_str(), metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
}
|
||||
} else {
|
||||
std::cerr << "Unsupported data type " << data_type << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
while (1) {
|
||||
try {
|
||||
setup(address, data_type);
|
||||
std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl;
|
||||
std::string line;
|
||||
std::getline(std::cin, line);
|
||||
if (line == "exit") {
|
||||
teardown(address);
|
||||
g_httpServer->close().wait();
|
||||
exit(0);
|
||||
}
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << "Exception occurred: " << ex.what() << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
} catch (...) {
|
||||
std::cerr << "Unknown exception occurreed" << std::endl;
|
||||
std::cerr << "Restarting HTTP server";
|
||||
teardown(address);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,8 +10,6 @@
|
|||
#include <boost/program_options.hpp>
|
||||
#include <omp.h>
|
||||
|
||||
|
||||
|
||||
#include <restapi/server.h>
|
||||
|
||||
using namespace diskann;
|
||||
|
@ -22,7 +20,7 @@ std::vector<std::unique_ptr<diskann::BaseSearch>> g_ssdSearch;
|
|||
|
||||
void setup(const utility::string_t& address, const std::string& typestring) {
|
||||
web::http::uri_builder uriBldr(address);
|
||||
auto uri = uriBldr.to_uri();
|
||||
auto uri = uriBldr.to_uri();
|
||||
|
||||
std::cout << "Attempting to start server on " << uri.to_string() << std::endl;
|
||||
|
||||
|
@ -40,21 +38,20 @@ void teardown(const utility::string_t& address) {
|
|||
|
||||
int main(int argc, char* argv[]) {
|
||||
std::string data_type, index_path_prefix, address, dist_fn, tags_file;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
uint32_t num_nodes_to_cache;
|
||||
uint32_t num_threads;
|
||||
|
||||
po::options_description desc{ "Arguments" };
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("data_type",
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address",
|
||||
po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
po::value<std::string>(&data_type)->required(),
|
||||
"data type <int8/uint8/float>");
|
||||
desc.add_options()("address", po::value<std::string>(&address)->required(),
|
||||
"Web server address");
|
||||
desc.add_options()("index_path_prefix",
|
||||
po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
po::value<std::string>(&index_path_prefix)->required(),
|
||||
"Path prefix for loading index file components");
|
||||
desc.add_options()(
|
||||
"num_nodes_to_cache",
|
||||
po::value<uint32_t>(&num_nodes_to_cache)->default_value(0),
|
||||
|
@ -64,52 +61,52 @@ int main(int argc, char* argv[]) {
|
|||
po::value<uint32_t>(&num_threads)->default_value(omp_get_num_procs()),
|
||||
"Number of threads used for building index (defaults to "
|
||||
"omp_get_num_procs())");
|
||||
desc.add_options()("dist_fn", po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()("tags_file",
|
||||
desc.add_options()("dist_fn",
|
||||
po::value<std::string>(&dist_fn)->default_value("l2"),
|
||||
"distance function <l2/mips>");
|
||||
desc.add_options()(
|
||||
"tags_file",
|
||||
po::value<std::string>(&tags_file)->default_value(std::string()),
|
||||
"Tags file location");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
std::cerr << ex.what() << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
diskann::Metric metric;
|
||||
if (dist_fn == std::string("l2"))
|
||||
metric = diskann::Metric::L2;
|
||||
metric = diskann::Metric::L2;
|
||||
else if (dist_fn == std::string("mips"))
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
metric = diskann::Metric::INNER_PRODUCT;
|
||||
else {
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
std::cout << "Error. Only l2 and mips distance functions are supported"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
if (data_type == std::string("float")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<float>(
|
||||
index_path_prefix, num_nodes_to_cache, num_threads,
|
||||
tags_file, metric));
|
||||
new diskann::PQFlashSearch<float>(index_path_prefix, num_nodes_to_cache,
|
||||
num_threads, tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("int8")) {
|
||||
auto searcher =
|
||||
std::unique_ptr<diskann::BaseSearch>(new diskann::PQFlashSearch<int8_t>(
|
||||
index_path_prefix, num_nodes_to_cache, num_threads,
|
||||
tags_file, metric));
|
||||
index_path_prefix, num_nodes_to_cache, num_threads, tags_file,
|
||||
metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
auto searcher = std::unique_ptr<diskann::BaseSearch>(
|
||||
new diskann::PQFlashSearch<uint8_t>(
|
||||
index_path_prefix, num_nodes_to_cache, num_threads,
|
||||
tags_file, metric));
|
||||
new diskann::PQFlashSearch<uint8_t>(index_path_prefix,
|
||||
num_nodes_to_cache, num_threads,
|
||||
tags_file, metric));
|
||||
g_ssdSearch.push_back(std::move(searcher));
|
||||
} else {
|
||||
std::cerr << "Unsupported data type " << argv[2] << std::endl;
|
||||
|
|
|
@ -44,14 +44,15 @@ void print_stats(std::string category, std::vector<float> percentiles,
|
|||
diskann::cout << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
int search_disk_index(
|
||||
diskann::Metric& metric, const std::string& index_path_prefix,
|
||||
const std::string& result_output_prefix, const std::string& query_file,
|
||||
std::string& gt_file, const unsigned num_threads, const unsigned recall_at,
|
||||
const unsigned beamwidth, const unsigned num_nodes_to_cache,
|
||||
const _u32 search_io_limit, const std::vector<unsigned>& Lvec,
|
||||
const bool use_reorder_data, const float fail_if_recall_below) {
|
||||
const float fail_if_recall_below, const bool use_reorder_data = false,
|
||||
const std::string& filter_label = "") {
|
||||
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
|
||||
if (beamwidth <= 0)
|
||||
diskann::cout << "beamwidth to be optimized for each L value" << std::flush;
|
||||
|
@ -62,6 +63,10 @@ int search_disk_index(
|
|||
else
|
||||
diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl;
|
||||
|
||||
bool filtered_search = false;
|
||||
if (filter_label != "")
|
||||
filtered_search = true;
|
||||
|
||||
std::string warmup_query_file = index_path_prefix + "_sample_data.bin";
|
||||
|
||||
// load query bin
|
||||
|
@ -95,8 +100,8 @@ int search_disk_index(
|
|||
reader.reset(new LinuxAlignedFileReader());
|
||||
#endif
|
||||
|
||||
std::unique_ptr<diskann::PQFlashIndex<T>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T>(reader, metric));
|
||||
std::unique_ptr<diskann::PQFlashIndex<T, LabelT>> _pFlashIndex(
|
||||
new diskann::PQFlashIndex<T, LabelT>(reader, metric));
|
||||
|
||||
int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str());
|
||||
|
||||
|
@ -207,11 +212,22 @@ int search_disk_index(
|
|||
|
||||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (_s64 i = 0; i < (int64_t) query_num; i++) {
|
||||
_pFlashIndex->cached_beam_search(
|
||||
query + (i * query_aligned_dim), recall_at, L,
|
||||
query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at),
|
||||
optimized_beamwidth, search_io_limit, use_reorder_data, stats + i);
|
||||
if (!filtered_search) {
|
||||
_pFlashIndex->cached_beam_search(
|
||||
query + (i * query_aligned_dim), recall_at, L,
|
||||
query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at),
|
||||
optimized_beamwidth, use_reorder_data, stats + i);
|
||||
} else {
|
||||
LabelT label_for_search =
|
||||
_pFlashIndex->get_converted_label(filter_label);
|
||||
_pFlashIndex->cached_beam_search(
|
||||
query + (i * query_aligned_dim), recall_at, L,
|
||||
query_result_ids_64.data() + (i * recall_at),
|
||||
query_result_dists[test_id].data() + (i * recall_at),
|
||||
optimized_beamwidth, true, label_for_search, use_reorder_data,
|
||||
stats + i);
|
||||
}
|
||||
}
|
||||
auto e = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = e - s;
|
||||
|
@ -282,7 +298,7 @@ int search_disk_index(
|
|||
|
||||
int main(int argc, char** argv) {
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path_prefix,
|
||||
query_file, gt_file;
|
||||
query_file, gt_file, filter_label, label_type;
|
||||
unsigned num_threads, K, W, num_nodes_to_cache, search_io_limit;
|
||||
std::vector<unsigned> Lvec;
|
||||
bool use_reorder_data = false;
|
||||
|
@ -333,6 +349,15 @@ int main(int argc, char** argv) {
|
|||
po::bool_switch()->default_value(false),
|
||||
"Include full precision data in the index. Use only in "
|
||||
"conjuction with compressed data on SSD.");
|
||||
desc.add_options()(
|
||||
"filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
"Filter Label for Filtered Search");
|
||||
desc.add_options()(
|
||||
"label_type",
|
||||
po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
desc.add_options()(
|
||||
"fail_if_recall_below",
|
||||
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
|
||||
|
@ -382,29 +407,52 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
try {
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
use_reorder_data, fail_if_recall_below);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
use_reorder_data, fail_if_recall_below);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
use_reorder_data, fail_if_recall_below);
|
||||
else {
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
if (filter_label != "" && label_type == "ushort") {
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else {
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
if (data_type == std::string("float"))
|
||||
return search_disk_index<float>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else if (data_type == std::string("int8"))
|
||||
return search_disk_index<int8_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else if (data_type == std::string("uint8"))
|
||||
return search_disk_index<uint8_t>(
|
||||
metric, index_path_prefix, result_path_prefix, query_file, gt_file,
|
||||
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
|
||||
fail_if_recall_below, use_reorder_data, filter_label);
|
||||
else {
|
||||
std::cerr << "Unsupported data type. Use float or int8 or uint8"
|
||||
<< std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} catch (const std::exception& e) {
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Index search failed." << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -23,7 +23,7 @@
|
|||
|
||||
namespace po = boost::program_options;
|
||||
|
||||
template<typename T>
|
||||
template<typename T, typename LabelT = uint32_t>
|
||||
int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
||||
const std::string& result_path_prefix,
|
||||
const std::string& query_file,
|
||||
|
@ -32,7 +32,8 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
|||
const bool print_all_recalls,
|
||||
const std::vector<unsigned>& Lvec, const bool dynamic,
|
||||
const bool tags, const bool show_qps_per_thread,
|
||||
const float fail_if_recall_below) {
|
||||
const std::string& filter_label,
|
||||
const float fail_if_recall_below) {
|
||||
// Load the query file
|
||||
T* query = nullptr;
|
||||
unsigned* gt_ids = nullptr;
|
||||
|
@ -54,8 +55,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
|||
diskann::cout << " Truthset file " << truthset_file
|
||||
<< " not found. Not computing recall." << std::endl;
|
||||
}
|
||||
|
||||
bool filtered_search = false;
|
||||
if (filter_label != "") {
|
||||
filtered_search = true;
|
||||
}
|
||||
|
||||
using TagT = uint32_t;
|
||||
diskann::Index<T, TagT> index(metric, query_dim, 0, dynamic, tags);
|
||||
diskann::Index<T, TagT, LabelT> index(metric, query_dim, 0, dynamic, tags);
|
||||
std::cout << "Index class instantiated" << std::endl;
|
||||
index.load(index_path.c_str(), num_threads,
|
||||
*(std::max_element(Lvec.begin(), Lvec.end())));
|
||||
|
@ -117,6 +124,7 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
|||
}
|
||||
|
||||
query_result_ids[test_id].resize(recall_at * query_num);
|
||||
query_result_dists[test_id].resize(recall_at * query_num);
|
||||
std::vector<T*> res = std::vector<T*>();
|
||||
|
||||
auto s = std::chrono::high_resolution_clock::now();
|
||||
|
@ -124,7 +132,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
|||
#pragma omp parallel for schedule(dynamic, 1)
|
||||
for (int64_t i = 0; i < (int64_t) query_num; i++) {
|
||||
auto qs = std::chrono::high_resolution_clock::now();
|
||||
if (metric == diskann::FAST_L2) {
|
||||
if (filtered_search) {
|
||||
LabelT filter_label_as_num = index.get_converted_label(filter_label);
|
||||
auto retval = index.search_with_filters(
|
||||
query + i * query_aligned_dim, filter_label_as_num, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at,
|
||||
query_result_dists[test_id].data() + i * recall_at);
|
||||
cmp_stats[i] = retval.second;
|
||||
} else if (metric == diskann::FAST_L2) {
|
||||
index.search_with_optimized_layout(
|
||||
query + i * query_aligned_dim, recall_at, L,
|
||||
query_result_ids[test_id].data() + i * recall_at);
|
||||
|
@ -212,10 +227,9 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path,
|
|||
return best_recall >= fail_if_recall_below ? 0 : -1;
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string data_type, dist_fn, index_path_prefix, result_path, query_file,
|
||||
gt_file;
|
||||
gt_file, filter_label, label_type;
|
||||
unsigned num_threads, K;
|
||||
std::vector<unsigned> Lvec;
|
||||
bool print_all_recalls, dynamic, tags, show_qps_per_thread;
|
||||
|
@ -238,6 +252,15 @@ int main(int argc, char** argv) {
|
|||
desc.add_options()("query_file",
|
||||
po::value<std::string>(&query_file)->required(),
|
||||
"Query file in binary format");
|
||||
desc.add_options()(
|
||||
"filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(std::string("")),
|
||||
"Filter Label for Filtered Search");
|
||||
desc.add_options()(
|
||||
"label_type",
|
||||
po::value<std::string>(&label_type)->default_value("uint"),
|
||||
"Storage type of Labels <uint/ushort>, default value is uint which "
|
||||
"will consume memory 4 bytes per filter");
|
||||
desc.add_options()(
|
||||
"gt_file",
|
||||
po::value<std::string>(>_file)->default_value(std::string("null")),
|
||||
|
@ -313,26 +336,46 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
try {
|
||||
if (data_type == std::string("int8")) {
|
||||
return search_memory_index<int8_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, fail_if_recall_below);
|
||||
}
|
||||
|
||||
else if (data_type == std::string("uint8")) {
|
||||
return search_memory_index<uint8_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, fail_if_recall_below);
|
||||
} else if (data_type == std::string("float")) {
|
||||
return search_memory_index<float>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, fail_if_recall_below);
|
||||
if (filter_label != "" && label_type == "ushort") {
|
||||
if (data_type == std::string("int8")) {
|
||||
return search_memory_index<int8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
return search_memory_index<uint8_t, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else if (data_type == std::string("float")) {
|
||||
return search_memory_index<float, uint16_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else {
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
} else {
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
if (data_type == std::string("int8")) {
|
||||
return search_memory_index<int8_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else if (data_type == std::string("uint8")) {
|
||||
return search_memory_index<uint8_t>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else if (data_type == std::string("float")) {
|
||||
return search_memory_index<float>(
|
||||
metric, index_path_prefix, result_path, query_file, gt_file,
|
||||
num_threads, K, print_all_recalls, Lvec, dynamic, tags,
|
||||
show_qps_per_thread, filter_label, fail_if_recall_below);
|
||||
} else {
|
||||
std::cout << "Unsupported type. Use float/int8/uint8" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
} catch (std::exception& e) {
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
|
|
|
@ -160,6 +160,8 @@ void build_incremental_index(
|
|||
params.Set<bool>("saturate_graph", saturate_graph);
|
||||
params.Set<unsigned>("num_rnds", 1);
|
||||
params.Set<unsigned>("num_threads", thread_count);
|
||||
params.Set<int>("Lf", 0); // TODO: get this from params and default to some
|
||||
// value to make it backward compatible.
|
||||
|
||||
size_t dim, aligned_dim;
|
||||
size_t num_points;
|
||||
|
|
|
@ -83,9 +83,10 @@ std::string get_save_filename(const std::string& save_path,
|
|||
return final_path;
|
||||
}
|
||||
|
||||
template<typename T, typename TagT>
|
||||
void insert_next_batch(diskann::Index<T, TagT>& index, size_t start, size_t end,
|
||||
size_t insert_threads, T* data, size_t aligned_dim) {
|
||||
template<typename T, typename TagT, typename LabelT>
|
||||
void insert_next_batch(diskann::Index<T, TagT, LabelT>& index, size_t start,
|
||||
size_t end, size_t insert_threads, T* data,
|
||||
size_t aligned_dim) {
|
||||
try {
|
||||
diskann::Timer insert_timer;
|
||||
std::cout << std::endl
|
||||
|
@ -116,8 +117,8 @@ void insert_next_batch(diskann::Index<T, TagT>& index, size_t start, size_t end,
|
|||
}
|
||||
}
|
||||
|
||||
template<typename T, typename TagT>
|
||||
void delete_and_consolidate(diskann::Index<T, TagT>& index,
|
||||
template<typename T, typename TagT, typename LabelT>
|
||||
void delete_and_consolidate(diskann::Index<T, TagT, LabelT>& index,
|
||||
diskann::Parameters& delete_params, size_t start,
|
||||
size_t end) {
|
||||
try {
|
||||
|
@ -189,6 +190,7 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
|
|||
params.Set<bool>("saturate_graph", saturate_graph);
|
||||
params.Set<unsigned>("num_rnds", 1);
|
||||
params.Set<unsigned>("num_threads", insert_threads);
|
||||
params.Set<unsigned>("Lf", 0);
|
||||
diskann::Parameters delete_params;
|
||||
delete_params.Set<unsigned>("L", L);
|
||||
delete_params.Set<unsigned>("R", R);
|
||||
|
@ -222,6 +224,7 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
|
|||
__FUNCSIG__, __FILE__, __LINE__);
|
||||
|
||||
using TagT = uint32_t;
|
||||
using LabelT = uint32_t;
|
||||
unsigned num_frozen = 1;
|
||||
const bool enable_tags = true;
|
||||
|
||||
|
@ -232,9 +235,9 @@ void build_incremental_index(const std::string& data_path, const unsigned L,
|
|||
std::cout << "Overriding num_frozen to" << num_frozen << std::endl;
|
||||
}
|
||||
|
||||
diskann::Index<T, TagT> index(diskann::L2, dim,
|
||||
active_window + 4 * consolidate_interval, true,
|
||||
params, params, enable_tags, true);
|
||||
diskann::Index<T, TagT, LabelT> index(
|
||||
diskann::L2, dim, active_window + 4 * consolidate_interval, true, params,
|
||||
params, enable_tags, true);
|
||||
index.set_start_point_at_random(static_cast<T>(start_point_norm));
|
||||
index.enable_delete();
|
||||
|
||||
|
|
|
@ -66,3 +66,8 @@ target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK
|
|||
add_executable(create_disk_layout create_disk_layout.cpp)
|
||||
target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS})
|
||||
|
||||
add_executable(generate_synthetic_labels generate_synthetic_labels.cpp)
|
||||
target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options)
|
||||
|
||||
add_executable(stats_label_data stats_label_data.cpp)
|
||||
target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options)
|
||||
|
|
|
@ -302,6 +302,66 @@ inline void load_bin_as_float(const char *filename, float *&data,
|
|||
std::cout << "Finished converting part data to float." << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline std::vector<size_t> load_filtered_bin_as_float(
|
||||
const char *filename, float *&data, size_t &npts, size_t &ndims,
|
||||
int part_num, const char *label_file, const std::string &filter_label,
|
||||
const std::string &universal_label, size_t &npoints_filt,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels) {
|
||||
std::ifstream reader(filename, std::ios::binary);
|
||||
if (reader.fail()) {
|
||||
throw diskann::ANNException(std::string("Failed to open file ") + filename,
|
||||
-1);
|
||||
}
|
||||
|
||||
std::cout << "Reading bin file " << filename << " ...\n";
|
||||
int npts_i32, ndims_i32;
|
||||
std::vector<size_t> rev_map;
|
||||
reader.read((char *) &npts_i32, sizeof(int));
|
||||
reader.read((char *) &ndims_i32, sizeof(int));
|
||||
uint64_t start_id = part_num * PARTSIZE;
|
||||
uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t) npts_i32);
|
||||
npts = end_id - start_id;
|
||||
ndims = (unsigned) ndims_i32;
|
||||
uint64_t nptsuint64_t = (uint64_t) npts;
|
||||
uint64_t ndimsuint64_t = (uint64_t) ndims;
|
||||
npoints_filt = 0;
|
||||
std::cout << "#pts in part = " << npts << ", #dims = " << ndims
|
||||
<< ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B"
|
||||
<< std::endl;
|
||||
std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl;
|
||||
reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t),
|
||||
std::ios::beg);
|
||||
|
||||
T *data_T = new T[nptsuint64_t * ndimsuint64_t];
|
||||
reader.read((char *) data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t);
|
||||
std::cout << "Finished reading part of the bin file." << std::endl;
|
||||
reader.close();
|
||||
|
||||
data = aligned_malloc<float>(nptsuint64_t * ndimsuint64_t, ALIGNMENT);
|
||||
|
||||
for (int64_t i = 0; i < (int64_t) nptsuint64_t; i++) {
|
||||
if (std::find(pts_to_labels[start_id + i].begin(),
|
||||
pts_to_labels[start_id + i].end(),
|
||||
filter_label) != pts_to_labels[start_id + i].end() ||
|
||||
std::find(pts_to_labels[start_id + i].begin(),
|
||||
pts_to_labels[start_id + i].end(),
|
||||
universal_label) != pts_to_labels[start_id + i].end()) {
|
||||
rev_map.push_back(start_id + i);
|
||||
for (int64_t j = 0; j < (int64_t) ndimsuint64_t; j++) {
|
||||
float cur_val_float = (float) data_T[i * ndimsuint64_t + j];
|
||||
std::memcpy((char *) (data + npoints_filt * ndimsuint64_t + j),
|
||||
(char *) &cur_val_float, sizeof(float));
|
||||
}
|
||||
npoints_filt++;
|
||||
}
|
||||
}
|
||||
delete[] data_T;
|
||||
std::cout << "Finished converting part data to float.. identified "
|
||||
<< npoints_filt << " points matching the filter." << std::endl;
|
||||
return rev_map;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
inline void save_bin(const std::string filename, T *data, size_t npts,
|
||||
size_t ndims) {
|
||||
|
@ -334,19 +394,51 @@ inline void save_groundtruth_as_one_file(const std::string filename,
|
|||
<< 2 * npts * ndims * sizeof(unsigned) + 2 * sizeof(int) << "B"
|
||||
<< std::endl;
|
||||
|
||||
// data = new T[npts_u64 * ndims_u64];
|
||||
writer.write((char *) data, npts * ndims * sizeof(uint32_t));
|
||||
writer.write((char *) distances, npts * ndims * sizeof(float));
|
||||
writer.close();
|
||||
std::cout << "Finished writing truthset" << std::endl;
|
||||
}
|
||||
|
||||
inline void parse_label_file_into_vec(
|
||||
size_t &line_cnt, const std::string &map_file,
|
||||
std::vector<std::vector<std::string>> &pts_to_labels) {
|
||||
std::ifstream infile(map_file);
|
||||
std::string line, token;
|
||||
std::set<std::string> labels;
|
||||
infile.clear();
|
||||
infile.seekg(0, std::ios::beg);
|
||||
while (std::getline(infile, line)) {
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> lbls(0);
|
||||
|
||||
getline(iss, token, '\t');
|
||||
std::istringstream new_iss(token);
|
||||
while (getline(new_iss, token, ',')) {
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
lbls.push_back(token);
|
||||
labels.insert(token);
|
||||
}
|
||||
if (lbls.size() <= 0) {
|
||||
std::cout << "No label found";
|
||||
exit(-1);
|
||||
}
|
||||
std::sort(lbls.begin(), lbls.end());
|
||||
pts_to_labels.push_back(lbls);
|
||||
}
|
||||
std::cout << "Identified " << labels.size()
|
||||
<< " distinct label(s), and populated labels for "
|
||||
<< pts_to_labels.size() << " points" << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
int aux_main(const std::string &base_file, const std::string &query_file,
|
||||
const std::string >_file, size_t k,
|
||||
const diskann::Metric &metric,
|
||||
const std::string &tags_file = std::string("")) {
|
||||
size_t npoints, nqueries, dim;
|
||||
int aux_main(const std::string &base_file, const std::string &label_file,
|
||||
const std::string &query_file, const std::string >_file,
|
||||
size_t k, const std::string &filter_label,
|
||||
const std::string &universal_label, const diskann::Metric &metric,
|
||||
const std::string &tags_file = std::string("")) {
|
||||
size_t npoints, nqueries, dim, npoints_filt;
|
||||
|
||||
float *base_data;
|
||||
float *query_data;
|
||||
|
@ -392,25 +484,50 @@ int aux_main(const std::string &base_file, const std::string &query_file,
|
|||
int *closest_points = new int[nqueries * k];
|
||||
float *dist_closest_points = new float[nqueries * k];
|
||||
|
||||
std::vector<std::vector<std::string>> pts_to_labels;
|
||||
if (filter_label != "")
|
||||
parse_label_file_into_vec(npoints, label_file, pts_to_labels);
|
||||
std::vector<size_t> rev_map;
|
||||
|
||||
for (int p = 0; p < num_parts; p++) {
|
||||
size_t start_id = p * PARTSIZE;
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
if (filter_label == "") {
|
||||
load_bin_as_float<T>(base_file.c_str(), base_data, npoints, dim, p);
|
||||
} else {
|
||||
rev_map = load_filtered_bin_as_float<T>(
|
||||
base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(),
|
||||
filter_label, universal_label, npoints_filt, pts_to_labels);
|
||||
}
|
||||
int *closest_points_part = new int[nqueries * k];
|
||||
float *dist_closest_points_part = new float[nqueries * k];
|
||||
|
||||
auto nr = std::min(npoints, k);
|
||||
|
||||
exact_knn(dim, nr, closest_points_part, dist_closest_points_part, npoints,
|
||||
base_data, nqueries, query_data, metric);
|
||||
_u32 part_k;
|
||||
if (filter_label == "") {
|
||||
part_k = k < npoints ? k : npoints;
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part,
|
||||
npoints, base_data, nqueries, query_data, metric);
|
||||
} else {
|
||||
part_k = k < npoints_filt ? k : npoints_filt;
|
||||
if (npoints_filt > 0) {
|
||||
exact_knn(dim, part_k, closest_points_part, dist_closest_points_part,
|
||||
npoints_filt, base_data, nqueries, query_data, metric);
|
||||
}
|
||||
}
|
||||
|
||||
for (_u64 i = 0; i < nqueries; i++) {
|
||||
for (_u64 j = 0; j < nr; j++) {
|
||||
for (_u64 j = 0; j < part_k; j++) {
|
||||
if (tags_enabled)
|
||||
if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0)
|
||||
continue;
|
||||
results[i].push_back(std::make_pair(
|
||||
(uint32_t) (closest_points_part[i * nr + j] + start_id),
|
||||
dist_closest_points_part[i * nr + j]));
|
||||
if (filter_label == "") {
|
||||
results[i].push_back(std::make_pair(
|
||||
(uint32_t) (closest_points_part[i * part_k + j] + start_id),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
} else {
|
||||
results[i].push_back(std::make_pair(
|
||||
(uint32_t) (rev_map[closest_points_part[i * part_k + j]]),
|
||||
dist_closest_points_part[i * part_k + j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -455,8 +572,9 @@ int aux_main(const std::string &base_file, const std::string &query_file,
|
|||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file;
|
||||
uint64_t K;
|
||||
std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file,
|
||||
label_file, filter_label, universal_label;
|
||||
uint64_t K;
|
||||
|
||||
try {
|
||||
po::options_description desc{"Arguments"};
|
||||
|
@ -474,6 +592,17 @@ int main(int argc, char **argv) {
|
|||
desc.add_options()("query_file",
|
||||
po::value<std::string>(&query_file)->required(),
|
||||
"File containing the query vectors in binary format");
|
||||
desc.add_options()("label_file",
|
||||
po::value<std::string>(&label_file)->default_value(""),
|
||||
"Input labels file in txt format if present");
|
||||
desc.add_options()("filter_label",
|
||||
po::value<std::string>(&filter_label)->default_value(""),
|
||||
"Input filter label if doing filtered groundtruth");
|
||||
desc.add_options()(
|
||||
"universal_label",
|
||||
po::value<std::string>(&universal_label)->default_value(""),
|
||||
"Universal label, if using it, only in conjunction with label_file");
|
||||
|
||||
desc.add_options()(
|
||||
"gt_file", po::value<std::string>(>_file)->required(),
|
||||
"File name for the writing ground truth in binary format");
|
||||
|
@ -518,11 +647,14 @@ int main(int argc, char **argv) {
|
|||
|
||||
try {
|
||||
if (data_type == std::string("float"))
|
||||
aux_main<float>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
aux_main<float>(base_file, label_file, query_file, gt_file, K,
|
||||
filter_label, universal_label, metric, tags_file);
|
||||
if (data_type == std::string("int8"))
|
||||
aux_main<int8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
aux_main<int8_t>(base_file, label_file, query_file, gt_file, K,
|
||||
filter_label, universal_label, metric, tags_file);
|
||||
if (data_type == std::string("uint8"))
|
||||
aux_main<uint8_t>(base_file, query_file, gt_file, K, metric, tags_file);
|
||||
aux_main<uint8_t>(base_file, label_file, query_file, gt_file, K,
|
||||
filter_label, universal_label, metric, tags_file);
|
||||
} catch (const std::exception &e) {
|
||||
std::cout << std::string(e.what()) << std::endl;
|
||||
diskann::cerr << "Compute GT failed." << std::endl;
|
||||
|
|
|
@ -26,8 +26,9 @@ namespace po = boost::program_options;
|
|||
template<typename T>
|
||||
void bfs_count(const std::string& index_path, unsigned data_dims) {
|
||||
using TagT = uint32_t;
|
||||
diskann::Index<T, TagT> index(diskann::Metric::L2, data_dims, 0, false,
|
||||
false);
|
||||
using LabelT = uint32_t;
|
||||
diskann::Index<T, TagT, LabelT> index(diskann::Metric::L2, data_dims, 0,
|
||||
false, false);
|
||||
std::cout << "Index class instantiated" << std::endl;
|
||||
index.load(index_path.c_str(), 1, 100);
|
||||
std::cout << "Index loaded" << std::endl;
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <boost/program_options.hpp>
|
||||
#include <math.h>
|
||||
#include <cmath>
|
||||
#include "utils.h"
|
||||
|
||||
namespace po = boost::program_options;
|
||||
class ZipfDistribution {
|
||||
public:
|
||||
ZipfDistribution(int num_points, int num_labels)
|
||||
: uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)),
|
||||
num_points(num_points), num_labels(num_labels) {
|
||||
}
|
||||
|
||||
std::unordered_map<int, int> createDistributionMap() {
|
||||
std::unordered_map<int, int> map;
|
||||
int primary_label_freq = ceil(num_points * distribution_factor);
|
||||
for (int i{1}; i < num_labels + 1; i++) {
|
||||
map[i] = ceil(primary_label_freq / i);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
int writeDistribution(std::ofstream& outfile) {
|
||||
auto distribution_map = createDistributionMap();
|
||||
auto primary_label_frequency = num_points * distribution_factor;
|
||||
for (int i{0}; i < num_points; i++) {
|
||||
bool label_written = false;
|
||||
for (auto it = distribution_map.cbegin(), next_it = it;
|
||||
it != distribution_map.cend(); it = next_it) {
|
||||
next_it++;
|
||||
auto label_selection_probability = std::bernoulli_distribution(
|
||||
distribution_factor / (double) it->first);
|
||||
if (label_selection_probability(rand_engine)) {
|
||||
if (label_written) {
|
||||
outfile << ',';
|
||||
}
|
||||
outfile << it->first;
|
||||
label_written = true;
|
||||
// remove label from map if we have used all labels
|
||||
distribution_map[it->first] -= 1;
|
||||
if (distribution_map[it->first] == 0) {
|
||||
distribution_map.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!label_written) {
|
||||
outfile << 0;
|
||||
}
|
||||
if (i < num_points - 1) {
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int writeDistribution(std::string filename) {
|
||||
std::ofstream outfile(filename);
|
||||
if (!outfile.is_open()) {
|
||||
std::cerr << "Error: could not open output file " << filename << '\n';
|
||||
return -1;
|
||||
}
|
||||
writeDistribution(outfile);
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
private:
|
||||
int num_labels;
|
||||
const int num_points;
|
||||
const double distribution_factor = 0.7;
|
||||
std::knuth_b rand_engine;
|
||||
const std::uniform_real_distribution<double> uniform_zero_to_one;
|
||||
};
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string output_file, distribution_type;
|
||||
_u64 num_labels, num_points;
|
||||
|
||||
try {
|
||||
po::options_description desc{"Arguments"};
|
||||
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("output_file,O",
|
||||
po::value<std::string>(&output_file)->required(),
|
||||
"Filename for saving the label file");
|
||||
desc.add_options()("num_points,N",
|
||||
po::value<uint64_t>(&num_points)->required(),
|
||||
"Number of points in dataset");
|
||||
desc.add_options()("num_labels,L",
|
||||
po::value<uint64_t>(&num_labels)->required(),
|
||||
"Number of unique labels, up to 5000");
|
||||
desc.add_options()(
|
||||
"distribution_type,DT",
|
||||
po::value<std::string>(&distribution_type)->default_value("random"),
|
||||
"Distribution function for labels <random/zipf> defaults to random");
|
||||
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_labels > 5000) {
|
||||
std::cerr << "Error: num_labels must be 5000 or less" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (num_points <= 0) {
|
||||
std::cerr << "Error: num_points must be greater than 0" << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::cout << "Generating synthetic labels for " << num_points
|
||||
<< " points with " << num_labels << " unique labels" << '\n';
|
||||
|
||||
try {
|
||||
std::ofstream outfile(output_file);
|
||||
if (!outfile.is_open()) {
|
||||
std::cerr << "Error: could not open output file " << output_file << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (distribution_type == "zipf") {
|
||||
ZipfDistribution zipf(num_points, num_labels);
|
||||
zipf.writeDistribution(outfile);
|
||||
} else if (distribution_type == "random") {
|
||||
for (int i = 0; i < num_points; i++) {
|
||||
bool label_written = false;
|
||||
for (int j = 1; j <= num_labels; j++) {
|
||||
// 50% chance to assign each label
|
||||
if (rand() > (RAND_MAX / 2)) {
|
||||
if (label_written) {
|
||||
outfile << ',';
|
||||
}
|
||||
outfile << j;
|
||||
label_written = true;
|
||||
}
|
||||
}
|
||||
if (!label_written) {
|
||||
outfile << 0;
|
||||
}
|
||||
if (i < num_points - 1) {
|
||||
outfile << '\n';
|
||||
}
|
||||
}
|
||||
}
|
||||
if (outfile.is_open()) {
|
||||
outfile.close();
|
||||
}
|
||||
|
||||
std::cout << "Labels written to " << output_file << '\n';
|
||||
|
||||
} catch (const std::exception& ex) {
|
||||
std::cerr << "Label generation failed: " << ex.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,150 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
#include <atomic>
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <set>
|
||||
#include <boost/program_options.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#ifndef _WINDOWS
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/stat.h>
|
||||
#include <time.h>
|
||||
#else
|
||||
#include <Windows.h>
|
||||
#endif
|
||||
namespace po = boost::program_options;
|
||||
|
||||
void stats_analysis(const std::string labels_file, std::string univeral_label,
|
||||
_u32 density = 10) {
|
||||
std::string token, line;
|
||||
std::ifstream labels_stream(labels_file);
|
||||
std::unordered_map<std::string, _u32> label_counts;
|
||||
std::string label_with_max_points;
|
||||
_u32 max_points = 0;
|
||||
long long sum = 0;
|
||||
long long point_cnt = 0;
|
||||
float avg_labels_per_pt, avg_labels_per_pt_incl_0, mean_label_size,
|
||||
mean_label_size_incl_0;
|
||||
|
||||
std::vector<_u32> labels_per_point;
|
||||
_u32 dense_pts = 0;
|
||||
if (labels_stream.is_open()) {
|
||||
while (getline(labels_stream, line)) {
|
||||
point_cnt++;
|
||||
std::stringstream iss(line);
|
||||
_u32 lbl_cnt = 0;
|
||||
while (getline(iss, token, ',')) {
|
||||
lbl_cnt++;
|
||||
token.erase(std::remove(token.begin(), token.end(), '\n'), token.end());
|
||||
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end());
|
||||
if (label_counts.find(token) == label_counts.end())
|
||||
label_counts[token] = 0;
|
||||
label_counts[token]++;
|
||||
}
|
||||
if (lbl_cnt >= density) {
|
||||
dense_pts++;
|
||||
}
|
||||
labels_per_point.emplace_back(lbl_cnt);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "fraction of dense points with >= " << density << " labels = "
|
||||
<< (float) dense_pts / (float) labels_per_point.size() << std::endl;
|
||||
std::sort(labels_per_point.begin(), labels_per_point.end());
|
||||
|
||||
std::vector<std::pair<std::string, _u32>> label_count_vec;
|
||||
|
||||
for (auto it = label_counts.begin(); it != label_counts.end(); it++) {
|
||||
auto& lbl = *it;
|
||||
label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second));
|
||||
if (lbl.second > max_points) {
|
||||
max_points = lbl.second;
|
||||
label_with_max_points = lbl.first;
|
||||
}
|
||||
sum += lbl.second;
|
||||
}
|
||||
|
||||
sort(label_count_vec.begin(), label_count_vec.end(),
|
||||
[](const std::pair<std::string, _u32>& lhs,
|
||||
const std::pair<std::string, _u32>& rhs) {
|
||||
return lhs.second < rhs.second;
|
||||
});
|
||||
|
||||
for (float p = 0; p < 1; p += 0.05) {
|
||||
std::cout << "Percentile " << (100 * p) << "\t"
|
||||
<< label_count_vec[(_u32) (p * label_count_vec.size())].first
|
||||
<< " with count="
|
||||
<< label_count_vec[(_u32) (p * label_count_vec.size())].second
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
std::cout << "Most common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 1].first
|
||||
<< " with count="
|
||||
<< label_count_vec[label_count_vec.size() - 1].second << std::endl;
|
||||
if (label_count_vec.size() > 1)
|
||||
std::cout << "Second common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 2].first
|
||||
<< " with count="
|
||||
<< label_count_vec[label_count_vec.size() - 2].second
|
||||
<< std::endl;
|
||||
if (label_count_vec.size() > 2)
|
||||
std::cout << "Third common label "
|
||||
<< "\t" << label_count_vec[label_count_vec.size() - 3].first
|
||||
<< " with count="
|
||||
<< label_count_vec[label_count_vec.size() - 3].second
|
||||
<< std::endl;
|
||||
avg_labels_per_pt = (sum) / (float) point_cnt;
|
||||
mean_label_size = (sum) / label_counts.size();
|
||||
std::cout << "Total number of points = " << point_cnt
|
||||
<< ", number of labels = " << label_counts.size() << std::endl;
|
||||
std::cout << "Average number of labels per point = " << avg_labels_per_pt
|
||||
<< std::endl;
|
||||
std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl;
|
||||
std::cout << "Most popular label is " << label_with_max_points << " with "
|
||||
<< max_points << " pts" << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::string labels_file, universal_label;
|
||||
_u32 density;
|
||||
|
||||
po::options_description desc{"Arguments"};
|
||||
try {
|
||||
desc.add_options()("help,h", "Print information on arguments");
|
||||
desc.add_options()("labels_file",
|
||||
po::value<std::string>(&labels_file)->required(),
|
||||
"path to labels data file.");
|
||||
desc.add_options()("universal_label",
|
||||
po::value<std::string>(&universal_label)->required(),
|
||||
"Universal label used in labels file.");
|
||||
desc.add_options()(
|
||||
"density", po::value<_u32>(&density)->default_value(1),
|
||||
"Number of labels each point in labels file, defaults to 1");
|
||||
po::variables_map vm;
|
||||
po::store(po::parse_command_line(argc, argv, desc), vm);
|
||||
if (vm.count("help")) {
|
||||
std::cout << desc;
|
||||
return 0;
|
||||
}
|
||||
po::notify(vm);
|
||||
} catch (const std::exception& e) {
|
||||
std::cerr << e.what() << '\n';
|
||||
return -1;
|
||||
}
|
||||
stats_analysis(labels_file, universal_label, density);
|
||||
}
|
|
@ -0,0 +1,126 @@
|
|||
**Usage for filtered indices**
|
||||
================================
|
||||
## Building a filtered Index
|
||||
DiskANN provides two algorithms for building an index with filters support: filtered-vamana and stitched-vamana. Here, we describe the parameters for building both. `tests/build_memory_index.cpp` and `tests/build_stitched_index.cpp` are respectively used to build each kind of index.
|
||||
|
||||
### 1. filtered-vamana
|
||||
|
||||
1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported.
|
||||
2. **`--dist_fn`**: There are two distance functions supported: minimum Euclidean distance (l2) and maximum inner product (mips).
|
||||
3. **`--data_file`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. The following `n*d*sizeof(T)` bytes contain the contents of the data one data point in time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices.
|
||||
4. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix.
|
||||
5. **`-R (--max_degree)`** (default is 64): the degree of the graph index, typically between 32 and 150. Larger R will result in larger indices and longer indexing times, but might yield better search quality.
|
||||
6. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during index building. Typical values are between 75 to 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality. Note that this is to be used only for building an unfiltered index. The corresponding search list parameter for a filtered index is managed by `--FilteredLbuild`.
|
||||
7. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 to 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs.
|
||||
8. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth).
|
||||
9. **`--build_PQ_bytes`** (default is 0): Set to a positive value less than the dimensionality of the data to enable faster index build with PQ based distance comparisons. Defaults to using full precision vectors for distance comparisons.
|
||||
10. **`--use_opq`**: use the flag to use OPQ rather than PQ compression. OPQ is more space efficient for some high dimensional datasets, but also needs a bit more build time.
|
||||
11. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_file`.
|
||||
12. **`--universal_label`**: Optionally, the the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point.
|
||||
13. **`--FilteredLbuild`**: If building a filtered index, we maintain a separate search list from the one provided by `--Lbuild`.
|
||||
|
||||
### 2. stitched-vamana
|
||||
1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported.
|
||||
2. **`--data_path`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. The following `n*d*sizeof(T)` bytes contain the contents of the data one data point in time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices.
|
||||
3. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix.
|
||||
4. **`-R (--max_degree)`** (default is 64): Recall that stitched-vamana first builds a sub-index for each filter. This parameter sets the max degree for each sub-index.
|
||||
5. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during sub-index building. Typical values are between 75 to 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality.
|
||||
6. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 to 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs.
|
||||
7. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth).
|
||||
8. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_file`.
|
||||
9. **`--universal_label`**: Optionally, the the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point.
|
||||
10. **`--Stitched_R`**: Once all sub-indices are "stitched" together, we prune the resulting graph down to the degree given by this parameter.
|
||||
|
||||
## Computing a groundtruth file for a filtered index
|
||||
In order to evaluate the performance of our algorithms, we can compare its results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `tests/utils/compute_groundtruth.cpp` to provide the results for the latter:
|
||||
|
||||
1. **`--data_type`** The type of dataset you built an index with. float(32 bit), signed int8 and unsigned uint8 are supported.
|
||||
2. **`--dist_fn`**: There are two distance functions supported: l2 and mips.
|
||||
3. **`--base_file`**: The input data over which to build an index, in .bin format. Corresponds to the `--data_path` argument from above.
|
||||
4. **`--query_file`**: The queries to be searched on, which are stored in the same .bin format.
|
||||
5. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_file`.
|
||||
6. **`--filter_label`**: Filter for each query. For each query, a search is performed with this filter.
|
||||
7. **`--universal_label`**: Corresponds to the universal label passed when building an index with filter support.
|
||||
8. **`--gt_file`**: File to output results to. The binary file starts with `n`, the number of queries (4 bytes), followed by `d`, the number of ground truth elements per query (4 bytes), followed by `n*d` entries per query representing the `d` closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes.
|
||||
9. **`-K`**: The number of nearest neighbors to compute for each query.
|
||||
|
||||
|
||||
|
||||
## Searching a Filtered Index
|
||||
|
||||
Searching a filtered index uses the `tests/search_memory_index.cpp`:
|
||||
|
||||
1. **`--data_type`**: The type of dataset you built the index on. float(32 bit), signed int8 and unsigned uint8 are supported. Use the same data type as in arg (1) above used in building the index.
|
||||
2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. There is an additional *fast_l2* implementation that could provide faster results for small (about a million-sized) indices. Use the same distance as in arg (2) above used in building the index. Note that stitched-vamana only supports l2.
|
||||
3. **`--index_path_prefix`**: index built above in argument (4).
|
||||
4. **`--result_path`**: search results will be stored in files, one per L value (see last arg), with specified prefix, in binary format.
|
||||
5. **`-T (--num_threads)`**: The number of threads used for searching. Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but may lead to higher per-query latency, especially if the DRAM bandwidth is a bottleneck. So find the balance depending on throughput and latency required for your application.
|
||||
6. **`--query_file`**: The queries to be searched on in same binary file format as the data file (ii) above. The query file must be the same type as in argument (1).
|
||||
7. **`--filter_label`**: The filter to be used when searching an index with filters. For each query, a search is performed with this filter.
|
||||
8. **`--gt_file`**: The ground truth file for the queries and data file used in index construction. Use "null" if you do not have this file and if you do not want to compute recall. Note that if building a filtered index, a special groundtruth must be computed, as described above.
|
||||
9. **`-K`**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors.
|
||||
10. **`-L (--search_list)`**: A list of search_list sizes to perform search with. Larger parameters will result in slower latencies, but higher accuracies. Must be atleast the value of *K* in (7).
|
||||
|
||||
Example with SIFT10K:
|
||||
--------------------
|
||||
We demonstrate how to work through this pipeline using the SIFT10K dataset (http://corpus-texmex.irisa.fr/). Before starting, make sure you have compiled diskANN according to the instructions in the README and can see the following binaries (paths with respect to repository root):
|
||||
- `build/tests/utils/compute_groundtruth`
|
||||
- `build/tests/utils/fvecs_to_bin`
|
||||
- `build/tests/build_memory_index`
|
||||
- `build/tests/build_stitched_index`
|
||||
- `build/tests/search_memory_index`
|
||||
|
||||
Now, download the base and query set and convert the data to binary format:
|
||||
```bash
|
||||
wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz
|
||||
tar -zxvf siftsmall.tar.gz
|
||||
build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin
|
||||
build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin
|
||||
```
|
||||
|
||||
We now need to make label file for our vectors. For convenience, we've included a synthetic label generator through which we can generate label file as follow
|
||||
```bash
|
||||
build/tests/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf
|
||||
```
|
||||
Note : `distribution_type` can be `rand` or `zipf`
|
||||
|
||||
This will genearate label file with 10000 data points with 50 distinct labels, ranging from 1 to 50 assigned using zipf distribution (0 is the universal label).
|
||||
|
||||
Label count for each unique label in the generated label file can be printed with help of following command
|
||||
```bash
|
||||
build/tests/utils/stats_label_data.exe --labels_file ./rand_labels_50_10K.txt --universal_label 0
|
||||
```
|
||||
|
||||
Note that neither approach is designed for use with random synthetic labels, which will lead to unpredictable accuracy at search time.
|
||||
|
||||
Now build and search the index and measure the recall using ground truth computed using bruteforce. We search for results with the filter 35.
|
||||
```bash
|
||||
build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --K 100 --label_file ./rand_labels_50_10K.txt --filter_label 35 --universal_label 0
|
||||
build/tests/build_memory_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index -R 32 --FilteredLbuild 50 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0
|
||||
build/tests/build_stitched_index --data_type float --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index -R 20 -L 40 --stitched_R 32 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0
|
||||
build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_filtered_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/filtered_search_results
|
||||
build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix data/sift/siftsmall_R20_L40_SR32_stitched_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/stitched_search_results
|
||||
```
|
||||
|
||||
The output of both searches is listed below. The throughput (Queries/sec) as well as mean and 99.9 latency in microseconds for each `L` parameter provided. (Measured on a physical machine with a Intel(R) Xeon(R) W-2145 CPU and 64 GB RAM)
|
||||
```
|
||||
Stitched Index
|
||||
Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10
|
||||
=================================================================================
|
||||
10 31324.39 37.33 116.79 311.90 17.80
|
||||
20 91357.57 44.36 193.06 1042.30 17.90
|
||||
30 69314.48 49.89 258.09 1398.00 18.20
|
||||
40 61421.29 60.52 289.08 1515.00 18.60
|
||||
50 54203.48 70.27 294.26 685.10 19.40
|
||||
100 52904.45 79.00 336.26 1018.80 19.50
|
||||
|
||||
Filtered Index
|
||||
Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10
|
||||
=================================================================================
|
||||
10 69671.84 21.48 45.25 146.20 11.60
|
||||
20 168577.20 38.94 100.54 547.90 18.20
|
||||
30 127129.41 52.95 126.83 768.40 19.70
|
||||
40 106349.04 62.38 167.23 899.10 20.90
|
||||
50 89952.33 70.95 189.12 1070.80 22.10
|
||||
100 56899.00 112.26 304.67 636.60 23.80
|
||||
```
|
Загрузка…
Ссылка в новой задаче