[Draft] GPU unit and integration tests (#377)

* Initial unit tests and KNN build test

* Fix linking error

* Fix TPT tests

* Change test files and fix tpt test issues

* Fix linking issues

* Fix buildssd tests and add new tests - new bug with SPTAG logger when running tests

* Add benchmark tests for PQ optimization

---------

Co-authored-by: MaggieQi <chenqi871025@gmail.com>
This commit is contained in:
Ben Karsin 2024-05-03 12:49:51 -10:00 коммит произвёл GitHub
Родитель 9d7da6908a
Коммит 4ecf2495eb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: B5690EEEBB952194
13 изменённых файлов: 1713 добавлений и 9 удалений

Просмотреть файл

@ -55,7 +55,7 @@ template<> __forceinline__ __host__ __device__ uint8_t INFTY<uint8_t>() {return
template<typename T> __device__ T BASE() {}
template<> __forceinline__ __device__ float BASE<float>() {return 1;}
template<> __forceinline__ __device__ int32_t BASE<int32_t>() {return 16384;}
template<> __forceinline__ __device__ int BASE<int>() {return 16384;}
template<> __forceinline__ __device__ uint32_t BASE<uint32_t>() {return 65536;}
@ -88,6 +88,24 @@ __device__ int32_t cosine_int8(int8_t* a, int8_t* b) {
return BASE<int32_t>() - prod;
}
template<int Dim>
__device__ float cosine_int8_rfloat(int8_t* a, int8_t* b) {
float prod=0;
float src=0;
float target=0;
uint32_t* newA = reinterpret_cast<uint32_t*>(a);
uint32_t* newB = reinterpret_cast<uint32_t*>(b);
for(int i=0; i<Dim/4; ++i) {
src = newA[i];
target = newB[i];
prod = (float)(__dp4a(src, target, (int32_t)prod));
}
return (float)(BASE<int32_t>()) - prod;
}
template<typename T, typename SUMTYPE, int Dim>
__forceinline__ __device__ SUMTYPE l2(T* aVec, T* bVec) {
SUMTYPE total[2]={0,0};
@ -103,6 +121,9 @@ template<typename T, typename SUMTYPE, int Dim, int metric>
__device__ SUMTYPE dist(T* a, T* b) {
if(metric == (int)DistMetric::Cosine) {
if(::cuda::std::is_same<T,int8_t>::value) {
// if(::cuda::std::is_same<SUMTYPE,float>::value) {
// return cosine_int8_rfloat<Dim>((int8_t*)a, (int8_t*)b);
// }
return cosine_int8<Dim>((int8_t*)a, (int8_t*)b);
}
return cosine<T,SUMTYPE,Dim>(a, b);

Просмотреть файл

@ -0,0 +1,128 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*
* Licensed under the MIT License.
*/
#ifndef _SPTAG_COMMON_CUDA_PERFTEST_H_
#define _SPTAG_COMMON_CUDA_PERFTEST_H_
#include "Refine.hxx"
#include "log.hxx"
#include "ThreadHeap.hxx"
#include "TPtree.hxx"
#include "GPUQuantizer.hxx"
#include <cuda/std/type_traits>
#include <chrono>
/***************************************************************************************
* Function called by SPTAG to create an initial graph on the GPU.
***************************************************************************************/
template<typename T>
void benchmarkDist(SPTAG::VectorIndex* index, int m_iGraphSize, int m_iNeighborhoodSize, int trees, int* results, int refines, int refineDepth, int graph, int leafSize, int initSize, int NUM_GPUS, int balanceFactor) {
int m_disttype = (int)index->GetDistCalcMethod();
size_t dataSize = (size_t)m_iGraphSize;
int KVAL = m_iNeighborhoodSize;
int dim = index->GetFeatureDim();
size_t rawSize = dataSize*dim;
if(index->m_pQuantizer != NULL) {
SPTAG::COMMON::PQQuantizer<int>* pq_quantizer = (SPTAG::COMMON::PQQuantizer<int>*)index->m_pQuantizer.get();
dim = pq_quantizer->GetNumSubvectors();
}
// srand(time(NULL)); // random number seed for TP tree random hyperplane partitions
srand(1); // random number seed for TP tree random hyperplane partitions
/*******************************
* Error checking
********************************/
int numDevicesOnHost;
CUDA_CHECK(cudaGetDeviceCount(&numDevicesOnHost));
if(numDevicesOnHost < NUM_GPUS) {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "HeadNumGPUs parameter %d, but only %d devices available on system. Exiting.\n", NUM_GPUS, numDevicesOnHost);
exit(1);
}
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Building Head graph with %d GPUs...\n", NUM_GPUS);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Total of %d GPU devices on system, using %d of them.\n", numDevicesOnHost, NUM_GPUS);
/**** Compute result batch sizes for each GPU ****/
std::vector<size_t> batchSize(NUM_GPUS);
std::vector<size_t> resPerGPU(NUM_GPUS);
for(int gpuNum=0; gpuNum < NUM_GPUS; ++gpuNum) {
CUDA_CHECK(cudaSetDevice(gpuNum));
resPerGPU[gpuNum] = dataSize / NUM_GPUS; // Results per GPU
if(dataSize % NUM_GPUS > gpuNum) resPerGPU[gpuNum]++;
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, gpuNum)); // Get avil. memory
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU %d - %s\n", gpuNum, prop.name);
size_t freeMem, totalMem;
CUDA_CHECK(cudaMemGetInfo(&freeMem, &totalMem));
size_t rawDataSize, pointSetSize, treeSize, resMemAvail, maxEltsPerBatch;
rawDataSize = rawSize*sizeof(T);
pointSetSize = sizeof(PointSet<T>);
treeSize = 20*dataSize;
resMemAvail = (freeMem*0.9) - (rawDataSize+pointSetSize+treeSize); // Only use 90% of total memory to be safe
maxEltsPerBatch = resMemAvail / (dim*sizeof(T) + KVAL*sizeof(int));
batchSize[gpuNum] = (std::min)(maxEltsPerBatch, resPerGPU[gpuNum]);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "Memory for rawData:%lu MiB, pointSet structure:%lu MiB, Memory for TP trees:%lu MiB, Memory left for results:%lu MiB, total vectors:%lu, batch size:%d, total batches:%d\n", rawSize/1000000, pointSetSize/1000000, treeSize/1000000, resMemAvail/1000000, resPerGPU[gpuNum], batchSize[gpuNum], (((batchSize[gpuNum]-1)+resPerGPU[gpuNum]) / batchSize[gpuNum]));
// If GPU memory is insufficient or so limited that we need so many batches it becomes inefficient, return error
if(batchSize[gpuNum] == 0 || ((int)resPerGPU[gpuNum]) / batchSize[gpuNum] > 10000) {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Insufficient GPU memory to build Head index on GPU %d. Available GPU memory:%lu MB, Points and tpt require:%lu MB, leaving a maximum batch size of %d results to be computed, which is too small to run efficiently.\n", gpuNum, (freeMem)/1000000, (rawDataSize+pointSetSize+treeSize)/1000000, maxEltsPerBatch);
exit(1);
}
}
std::vector<size_t> GPUOffset(NUM_GPUS);
GPUOffset[0] = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU 0: results:%lu, offset:%lu\n", resPerGPU[0], GPUOffset[0]);
for(int gpuNum=1; gpuNum < NUM_GPUS; ++gpuNum) {
GPUOffset[gpuNum] = GPUOffset[gpuNum-1] + resPerGPU[gpuNum-1];
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Debug, "GPU %d: results:%lu, offset:%lu\n", gpuNum, resPerGPU[gpuNum], GPUOffset[gpuNum]);
}
if(index->m_pQuantizer != NULL) {
buildGraphGPU<uint8_t, float>(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim);
}
else if(typeid(T) == typeid(float)) {
buildGraphGPU<T, float>(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim);
}
else if(typeid(T) == typeid(uint8_t) || typeid(T) == typeid(int8_t)) {
buildGraphGPU<T, int32_t>(index, (size_t)m_iGraphSize, KVAL, trees, results, graph, leafSize, NUM_GPUS, balanceFactor, batchSize.data(), GPUOffset.data(), resPerGPU.data(), dim);
}
else {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Selected datatype not currently supported.\n");
exit(1);
}
}
#endif

Просмотреть файл

@ -161,13 +161,13 @@ class TPtree {
tree_mem+= levels*sizeof(int*) + levels*Dim*sizeof(KEYTYPE);
tree_mem+= N*sizeof(int);
CUDA_CHECK(cudaMallocManaged(&node_sizes, num_nodes*sizeof(int)));
CUDA_CHECK(cudaMalloc(&node_sizes, num_nodes*sizeof(int)));
CUDA_CHECK(cudaMemset(node_sizes, 0, num_nodes*sizeof(int)));
CUDA_CHECK(cudaMallocManaged(&split_keys, num_internals*sizeof(KEYTYPE)));
CUDA_CHECK(cudaMalloc(&split_keys, num_internals*sizeof(KEYTYPE)));
tree_mem+= num_nodes*sizeof(int) + num_internals*sizeof(KEYTYPE);
CUDA_CHECK(cudaMallocManaged(&leafs, num_leaves*sizeof(LeafNode)));
CUDA_CHECK(cudaMalloc(&leafs, num_leaves*sizeof(LeafNode)));
tree_mem+=num_leaves*sizeof(LeafNode);
CUDA_CHECK(cudaMalloc(&leaf_points, N*sizeof(int)));
@ -206,7 +206,6 @@ class TPtree {
// For debugging purposes
************************************************************************************/
__host__ void print_tree() {
printf("nodes:%d, leaves:%d, levels:%d\n", num_nodes, num_leaves, levels);
int level_offset;
print_level_device<<<1,1>>>(node_sizes, split_keys, 1, leafs, leaf_points);
@ -466,7 +465,7 @@ __host__ void create_tptree_multigpu(TPtree** d_trees, PointSet<T>** ps, int N,
// Build TPT on each GPU
// construct_trees_multigpu<T>(d_trees, ps, N, NUM_GPUS, streams, balanceFactor);
if(index->m_pQuantizer == NULL) { // Build directly if no quantizer
if(index == NULL || index->m_pQuantizer == NULL) { // Build directly if no quantizer
construct_trees_multigpu<T>(d_trees, ps, N, NUM_GPUS, streams, balanceFactor);
}
else {

Просмотреть файл

@ -336,7 +336,7 @@ __host__ void get_query_groups(QueryGroup* groups, TPtree* tptree, PointSet<T>*
}
#define MAX_SHAPE 384
#define MAX_SHAPE 1024
template<typename T, typename SUMTYPE, int metric>
__global__ void findTailNeighbors_selector(PointSet<T>* headPS, PointSet<T>* tailPS, TPtree* tptree, int KVAL, DistPair<SUMTYPE>* results, size_t curr_batch_size, size_t numHeads, QueryGroup* groups, int dim) {
@ -346,6 +346,7 @@ __global__ void findTailNeighbors_selector(PointSet<T>* headPS, PointSet<T>* tai
RUN_TAIL_KERNEL(64)
RUN_TAIL_KERNEL(100)
RUN_TAIL_KERNEL(384)
RUN_TAIL_KERNEL(MAX_SHAPE)
}

Просмотреть файл

@ -28,6 +28,7 @@ if (CUDA_FOUND)
set(AnnService ${PROJECT_SOURCE_DIR}/AnnService)
include_directories(${AnnService})
include_directories(${PROJECT_SOURCE_DIR}/Test/cuda)
include_directories(${PROJECT_SOURCE_DIR}/ThirdParty/zstd/lib)
@ -38,12 +39,11 @@ if (CUDA_FOUND)
${AnnService}/inc/Core/Common/DistanceUtils.h
${AnnService}/inc/Core/Common/InstructionUtils.h
${AnnService}/inc/Core/Common/CommonUtils.h
${AnnService}/inc/Core/Common/cuda/tests/
)
list(REMOVE_ITEM GPU_SRC_FILES
${AnnService}/src/Core/Common/DistanceUtils.cpp
${AnnService}/src/Core/Common/InstructionUtils.cpp
${AnnService}/src/Core/Common/cuda/tests/
${AnnService}/Test/cuda
)
set_source_files_properties(${GPU_SRC_FILES} PROPERTIES CUDA_SOURCE_PROPERTY_FORMAT OBJ)
@ -83,6 +83,19 @@ if (CUDA_FOUND)
CUDA_ADD_EXECUTABLE(gpussdserving ${SSD_SERVING_HDR_FILES} ${SSD_SERVING_FILES})
target_link_libraries(gpussdserving GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES})
target_compile_definitions(gpussdserving PRIVATE ${Definition} _exe)
CUDA_ADD_LIBRARY(GPUSPTAGTests SHARED ${AnnService}/../Test/cuda/knn_tests.cu ${AnnService}/../Test/cuda/distance_tests.cu ${AnnService}/../Test/cuda/tptree_tests.cu ${AnnService}/../Test/cuda/buildssd_test.cu ${AnnService}/../Test/cuda/gpu_pq_perf.cu)
target_link_libraries(GPUSPTAGTests ${Boost_LIBRARIES} ${CUDA_LIBRARIES})
target_compile_definitions(GPUSPTAGTests PRIVATE ${Definition})
CUDA_ADD_EXECUTABLE(gpu_test ${AnnService}/../Test/cuda/cuda_tests.cpp)
target_link_libraries(gpu_test GPUSPTAGTests GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES})
target_compile_definitions(gpu_test PRIVATE ${Definition} _exe)
CUDA_ADD_EXECUTABLE(gpu_pq_test ${AnnService}/../Test/cuda/pq_perf.cpp)
target_link_libraries(gpu_pq_test GPUSPTAGTests GPUSPTAGLibStatic ${Boost_LIBRARIES} ${CUDA_LIBRARIES})
target_compile_definitions(gpu_pq_test PRIVATE ${Definition} _exe)
else()
message (STATUS "Could not find cuda.")
endif()

128
Test/cuda/buildssd_test.cu Normal file
Просмотреть файл

@ -0,0 +1,128 @@
#include "common.hxx"
#include "inc/Core/Common/cuda/TailNeighbors.hxx"
template<typename T, typename SUMTYPE, int dim, int K>
int GPUBuildSSDTest(int rows, int metric, int iters);
int GPUBuildSSDTest_All() {
int errors = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting Cosine BuildSSD tests\n");
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n");
errors += GPUBuildSSDTest<float, float, 10, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<float, float, 100, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<float, float, 200, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<float, float, 384, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<float, float, 1024, 8>(10000, (int)DistMetric::Cosine, 10);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int datatype...\n");
errors += GPUBuildSSDTest<int, int, 10, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int, int, 100, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int, int, 200, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int, int, 384, 8>(10000, (int)DistMetric::Cosine, 10);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype...\n");
errors += GPUBuildSSDTest<int8_t, int32_t, 100, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 200, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 384, 8>(10000, (int)DistMetric::Cosine, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 1024, 8>(10000, (int)DistMetric::Cosine, 10);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting L2 BuildSSD tests\n");
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n");
errors += GPUBuildSSDTest<float, float, 100, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<float, float, 200, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<float, float, 384, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<float, float, 1024, 8>(10000, (int)DistMetric::L2, 10);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int datatype...\n");
errors += GPUBuildSSDTest<int, int, 10, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int, int, 100, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int, int, 200, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int, int, 384, 8>(10000, (int)DistMetric::L2, 10);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype...\n");
errors += GPUBuildSSDTest<int8_t, int32_t, 100, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 200, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 384, 8>(10000, (int)DistMetric::L2, 10);
errors += GPUBuildSSDTest<int8_t, int32_t, 1024, 8>(10000, (int)DistMetric::L2, 10);
return errors;
}
template<typename T, typename SUMTYPE, int dim, int K>
int GPUBuildSSDTest(int rows, int metric, int iters) {
int errors = 0;
int num_heads = rows/10;
// Create random data for head vectors
T* head_data = create_dataset<T>(num_heads, dim);
T* d_head_data;
CUDA_CHECK(cudaMalloc(&d_head_data, dim*num_heads*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_head_data, head_data, dim*num_heads*sizeof(T), cudaMemcpyHostToDevice));
PointSet<T> h_head_ps;
h_head_ps.dim = dim;
h_head_ps.data = d_head_data;
PointSet<T>* d_head_ps;
CUDA_CHECK(cudaMalloc(&d_head_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_head_ps, &h_head_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
// Create random data for tail vectors
T* tail_data = create_dataset<T>(rows, dim);
T* d_tail_data;
CUDA_CHECK(cudaMalloc(&d_tail_data, dim*rows*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_tail_data, tail_data, dim*rows*sizeof(T), cudaMemcpyHostToDevice));
PointSet<T> h_tail_ps;
h_tail_ps.dim = dim;
h_tail_ps.data = d_tail_data;
PointSet<T>* d_tail_ps;
CUDA_CHECK(cudaMalloc(&d_tail_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_tail_ps, &h_tail_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
DistPair<SUMTYPE>* d_results;
CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(DistPair<SUMTYPE>)));
int TPTlevels = (int)std::log2(num_heads/100);
TPtree* h_tree = new TPtree;
h_tree->initialize(num_heads, TPTlevels, dim);
TPtree* d_tree;
CUDA_CHECK(cudaMalloc(&d_tree, sizeof(TPtree)));
// Alloc memory for QuerySet structure
int* d_queryMem;
CUDA_CHECK(cudaMalloc(&d_queryMem, rows*sizeof(int) + 2*h_tree->num_leaves*sizeof(int)));
QueryGroup* d_queryGroups;
CUDA_CHECK(cudaMalloc(&d_queryGroups, sizeof(QueryGroup)));
cudaStream_t stream;
CUDA_CHECK(cudaStreamCreate(&stream));
for(int i=0; i<iters; ++i) {
create_tptree_multigpu<T>(&h_tree, &d_head_ps, num_heads, TPTlevels, 1, &stream, 2, NULL);
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemcpy(d_tree, h_tree, sizeof(TPtree), cudaMemcpyHostToDevice));
get_query_groups<T,SUMTYPE>(d_queryGroups, d_tree, d_tail_ps, rows, (int)h_tree->num_leaves, d_queryMem, 1024, 32, dim);
CUDA_CHECK(cudaDeviceSynchronize());
if(metric == (int)DistMetric::Cosine) {
findTailNeighbors_selector<T,SUMTYPE,(int)DistMetric::Cosine><<<1024, 32, sizeof(DistPair<SUMTYPE>)*32*K>>>(d_head_ps, d_tail_ps, d_tree, K, d_results, rows, num_heads, d_queryGroups, dim);
}
else {
findTailNeighbors_selector<T,SUMTYPE,(int)DistMetric::L2><<<1024, 32, sizeof(DistPair<SUMTYPE>)*32*K>>>(d_head_ps, d_tail_ps, d_tree, K, d_results, rows, num_heads, d_queryGroups, dim);
}
CUDA_CHECK(cudaDeviceSynchronize());
}
CUDA_CHECK(cudaFree(d_head_data));
CUDA_CHECK(cudaFree(d_head_ps));
CUDA_CHECK(cudaFree(d_tail_data));
CUDA_CHECK(cudaFree(d_tail_ps));
CUDA_CHECK(cudaFree(d_results));
h_tree->destroy();
CUDA_CHECK(cudaFree(d_tree));
return errors;
}

59
Test/cuda/common.hxx Normal file
Просмотреть файл

@ -0,0 +1,59 @@
//#include "inc/Core/Common/cuda/KNN.hxx"
#include <cstdlib>
#include <chrono>
#define CHECK_ERRS(errs) \
if(errs > 0) { \
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "%d errors found\n", errs); \
}
#define CHECK_VAL(val,exp,errs) \
if(val != exp) { \
errs++; \
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s != %s\n",#val,#exp); \
}
#define CHECK_VAL_LT(val,exp,errs) \
if(val > exp) { \
errs++; \
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s > %s\n",#val,#exp); \
}
#define GPU_CHECK_VAL(val,exp,dtype,errs) \
dtype temp; \
CUDA_CHECK(cudaMemcpy(&temp, val, sizeof(dtype), cudaMemcpyDeviceToHost)); \
float eps = 0.01; \
if((float)temp>0.0 && ((float)temp*(1.0+eps) < (float)(exp) || (float)temp*(1.0-eps) > (float)(exp))) { \
errs++; \
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "%s != %s\n",#val,#exp); \
}
template<typename T>
T* create_dataset(size_t rows, int dim) {
srand(0);
T* h_data = new T[rows*dim];
for(size_t i=0; i<rows*dim; ++i) {
if(std::is_same<T,float>::value) {
h_data[i] = (rand()/(float)RAND_MAX);
}
else if(std::is_same<T,int>::value) {
h_data[i] = static_cast<T>((rand()%INT_MAX));
}
else if(std::is_same<T,uint8_t>::value) {
h_data[i] = static_cast<T>((rand()%127));
}
else if(std::is_same<T,int8_t>::value) {
h_data[i] = static_cast<T>((rand()%127));
}
}
return h_data;
}
/*
__global__ void count_leaf_sizes(LeafNode* leafs, int* node_ids, int N, int internal_nodes);
__global__ void assign_leaf_points_in_batch(LeafNode* leafs, int* leaf_points, int* node_ids, int N, int internal_nodes, int min_id, int max_id);
__global__ void assign_leaf_points_out_batch(LeafNode* leafs, int* leaf_points, int* node_ids, int N, int internal_nodes, int min_id, int max_id);
__global__ void compute_mean(KEYTYPE* split_keys, int* node_sizes, int num_nodes);
__global__ void initialize_rands(curandState* states, int iter);
*/

43
Test/cuda/cuda_tests.cpp Normal file
Просмотреть файл

@ -0,0 +1,43 @@
//#include "test_kernels.cu"
#define BOOST_TEST_MODULE GPU
#include <cstdlib>
#include <chrono>
#include <iostream>
#include <boost/test/included/unit_test.hpp>
#include <boost/filesystem.hpp>
int GPUBuildKNNTest();
BOOST_AUTO_TEST_CASE(RandomTests) {
BOOST_CHECK(1 == 1);
int errors = GPUBuildKNNTest();
printf("outside\n");
BOOST_CHECK(errors == 0);
}
/*
int GPUTestDistance_All();
BOOST_AUTO_TEST_CASE(DistanceTests) {
int errs = GPUTestDistance_All();
BOOST_CHECK(errs == 0);
}
int GPUBuildTPTTest();
BOOST_AUTO_TEST_CASE(TPTreeTests) {
int errs = GPUBuildTPTTest();
BOOST_CHECK(errs == 0);
}
int GPUBuildSSDTest_All();
BOOST_AUTO_TEST_CASE(BuildSSDTests) {
int errs = GPUBuildSSDTest_All();
BOOST_CHECK(errs == 0);
}
*/

202
Test/cuda/distance_tests.cu Normal file
Просмотреть файл

@ -0,0 +1,202 @@
#include "common.hxx"
#include "inc/Core/Common/cuda/KNN.hxx"
#define GPU_CHECK_CORRECT(a,b,msg) \
if(a != (SUMTYPE)b) { \
printf(msg); \
errs = 0; \
return; \
}
template<typename T, typename SUMTYPE, int Dim>
__global__ void GPUTestDistancesKernelStatic(PointSet<T>* ps, int errs) {
SUMTYPE ab, ac, bc;
// Expected results
SUMTYPE l2_res[3] = {Dim, 4*Dim, Dim};
SUMTYPE cosine_res[3] = {BASE<SUMTYPE>(), BASE<SUMTYPE>(), BASE<SUMTYPE>()-2*Dim};
// l2 dist
ab = dist<T, SUMTYPE, Dim, (int)DistMetric::L2>(ps->getVec(0), ps->getVec(1));
ac = dist<T, SUMTYPE, Dim, (int)DistMetric::L2>(ps->getVec(0), ps->getVec(2));
bc = dist<T, SUMTYPE, Dim, (int)DistMetric::L2>(ps->getVec(1), ps->getVec(2));
GPU_CHECK_CORRECT(ab,l2_res[0],"Static L2 distance check failed\n");
GPU_CHECK_CORRECT(ac,l2_res[1], "Static L2 distance check failed\n");
GPU_CHECK_CORRECT(bc,l2_res[2], "Static L2 distance check failed\n");
// cosine dist
ab = dist<T, SUMTYPE, Dim, (int)DistMetric::Cosine>(ps->getVec(0), ps->getVec(1));
ac = dist<T, SUMTYPE, Dim, (int)DistMetric::Cosine>(ps->getVec(0), ps->getVec(2));
bc = dist<T, SUMTYPE, Dim, (int)DistMetric::Cosine>(ps->getVec(1), ps->getVec(2));
GPU_CHECK_CORRECT(ab,cosine_res[0],"Static Cosine distance check failed\n");
GPU_CHECK_CORRECT(ac,cosine_res[1],"Static Cosine distance check failed\n");
GPU_CHECK_CORRECT(bc,cosine_res[2],"Static Cosine distance check failed\n");
}
template<typename T, typename SUMTYPE, int dim>
int GPUTestDistancesSimple() {
T* h_data = new T[dim*3];
for(int i=0; i<3; ++i) {
for(int j=0; j<dim; ++j) {
h_data[i*dim+j] = (T)(i);
}
}
T* d_data;
CUDA_CHECK(cudaMalloc(&d_data, dim*3*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_data, h_data, dim*3*sizeof(T), cudaMemcpyHostToDevice));
PointSet<T> h_ps;
h_ps.dim = dim;
h_ps.data = d_data;
PointSet<T>* d_ps;
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
int errs=0;
GPUTestDistancesKernelStatic<T,SUMTYPE,dim><<<1, 1>>>(d_ps, errs); // TODO - make sure result is correct returned
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaFree(d_data));
CUDA_CHECK(cudaFree(d_ps));
return errs;
}
template<typename T, typename SUMTYPE, int dim>
__global__ void GPUTestDistancesRandomKernel(PointSet<T>* ps, int vecs, SUMTYPE* cosine_dists, SUMTYPE* l2_dists, int errs) {
SUMTYPE cos, l2;
SUMTYPE diff;
float eps = 0.001; // exact distances are not expected, but should be within an epsilon
for(int i=blockIdx.x*blockDim.x + threadIdx.x; i<vecs; i+= gridDim.x*blockDim.x) {
for(int j=0; j<vecs; ++j) {
cos = dist<T,SUMTYPE,dim, (int)DistMetric::Cosine>(ps->getVec(i), ps->getVec(j));
if( (((float)(cos - (SUMTYPE)cosine_dists[i*vecs+j]))/(float)cos) > eps) {
printf("Error, cosine calculations differ by too much i:%d, j:%d - GPU:%u, CPU:%u, diff:%f\n", i, j, cos, (SUMTYPE)cosine_dists[i*vecs+j], (((float)(cos - (SUMTYPE)cosine_dists[i*vecs+j]))/(float)cos));
return;
}
l2 = dist<T,SUMTYPE,dim, (int)DistMetric::L2>(ps->getVec(i), ps->getVec(j));
if( (((float)(l2 - (SUMTYPE)l2_dists[i*vecs+j]))/(float)l2) > eps) {
printf("Error, l2 calculations differ by too much i:%d, j:%d - GPU:%d, CPU:%d\n", i, j, l2, (int)l2_dists[i*vecs+j]);
return;
}
}
}
}
template<typename T, typename SUMTYPE, int dim>
int GPUTestDistancesComplex(int vecs) {
srand(time(NULL));
T* h_data = new T[vecs*dim];
for(int i=0; i<vecs; ++i) {
for(int j=0; j<dim; ++j) {
if(std::is_same<T,float>::value) {
h_data[i*dim+j] = (rand()/(float)RAND_MAX);
}
else if(std::is_same<T,int>::value) {
h_data[i*dim+j] = static_cast<T>((rand()%INT_MAX));
}
else if(std::is_same<T,uint8_t>::value) {
h_data[i*dim+j] = static_cast<T>((rand()%127));
}
else if(std::is_same<T,int8_t>::value) {
h_data[i*dim+j] = static_cast<T>((rand()%127));
}
}
}
// Compute CPU distances to verify with GPU metric
SUMTYPE* cpu_cosine_dists = new SUMTYPE[vecs*vecs];
SUMTYPE* cpu_l2_dists = new SUMTYPE[vecs*vecs];
for(int i=0; i<vecs; ++i) {
for(int j=0; j<vecs; ++j) {
cpu_cosine_dists[i*vecs+j] = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeCosineDistance<T>(&h_data[i*dim], &h_data[j*dim], dim));
cpu_l2_dists[i*vecs+j] = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeL2Distance<T>(&h_data[i*dim], &h_data[j*dim], dim));
}
}
int errs=0;
T* d_data;
CUDA_CHECK(cudaMalloc(&d_data, dim*vecs*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_data, h_data, dim*vecs*sizeof(T), cudaMemcpyHostToDevice));
PointSet<T> h_ps;
h_ps.dim = dim;
h_ps.data = d_data;
PointSet<T>* d_ps;
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
SUMTYPE* d_cosine_dists;
CUDA_CHECK(cudaMalloc(&d_cosine_dists, vecs*vecs*sizeof(SUMTYPE)));
CUDA_CHECK(cudaMemcpy(d_cosine_dists, cpu_cosine_dists, vecs*vecs*sizeof(SUMTYPE), cudaMemcpyHostToDevice));
SUMTYPE* d_l2_dists;
CUDA_CHECK(cudaMalloc(&d_l2_dists, vecs*vecs*sizeof(SUMTYPE)));
CUDA_CHECK(cudaMemcpy(d_l2_dists, cpu_l2_dists, vecs*vecs*sizeof(SUMTYPE), cudaMemcpyHostToDevice));
GPUTestDistancesRandomKernel<T,SUMTYPE,dim><<<1024, 32>>>(d_ps, vecs, d_cosine_dists, d_l2_dists, errs); // TODO - make sure result is correct returned
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaFree(d_data));
CUDA_CHECK(cudaFree(d_ps));
CUDA_CHECK(cudaFree(d_cosine_dists));
CUDA_CHECK(cudaFree(d_l2_dists));
return errs;
}
int GPUTestDistance_All() {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Static distance tests...\n");
// Test Distances with float datatype
int errs = 0;
errs += GPUTestDistancesSimple<float, float, 10>();
errs += GPUTestDistancesSimple<float, float, 100>();
errs += GPUTestDistancesSimple<float, float, 200>();
errs += GPUTestDistancesSimple<float, float, 384>();
errs += GPUTestDistancesSimple<float, float, 1024>();
// Test distances with int datatype
errs += GPUTestDistancesSimple<int, int, 10>();
errs += GPUTestDistancesSimple<int, int, 100>();
errs += GPUTestDistancesSimple<int, int, 200>();
errs += GPUTestDistancesSimple<int, int, 384>();
errs += GPUTestDistancesSimple<int, int, 1024>();
// Test distances with int8 datatype
errs += GPUTestDistancesSimple<int8_t, int32_t, 100>();
errs += GPUTestDistancesSimple<int8_t, int32_t, 200>();
errs += GPUTestDistancesSimple<int8_t, int32_t, 384>();
errs += GPUTestDistancesSimple<int8_t, int32_t, 1024>();
CHECK_ERRS(errs)
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Randomized vector distance tests...\n");
// Test distances between random vectors and compare with CPU calculation
errs += GPUTestDistancesComplex<float, float, 10>(100);
errs += GPUTestDistancesComplex<float, float, 100>(100);
errs += GPUTestDistancesComplex<float, float, 200>(100);
errs += GPUTestDistancesComplex<float, float, 384>(100);
errs += GPUTestDistancesComplex<float, float, 1024>(100);
errs += GPUTestDistancesComplex<int, int, 10>(100);
errs += GPUTestDistancesComplex<int, int, 100>(100);
errs += GPUTestDistancesComplex<int, int, 200>(100);
errs += GPUTestDistancesComplex<int, int, 384>(100);
errs += GPUTestDistancesComplex<int, int, 1024>(100);
errs += GPUTestDistancesComplex<int8_t, int32_t, 100>(100);
errs += GPUTestDistancesComplex<int8_t, int32_t, 200>(100);
errs += GPUTestDistancesComplex<int8_t, int32_t, 384>(100);
errs += GPUTestDistancesComplex<int8_t, int32_t, 1024>(100);
CHECK_ERRS(errs)
return errs;
}

135
Test/cuda/gpu_pq_perf.cu Normal file
Просмотреть файл

@ -0,0 +1,135 @@
#include "common.hxx"
#include "inc/Core/Common/cuda/KNN.hxx"
template<typename T, int Dim, int metric>
__global__ void top1_nopq_kernel(T* data, int* res_idx, float* results, int datasize) {
float temp;
for(int i=blockIdx.x*blockDim.x + threadIdx.x; i<datasize; i+=blockDim.x*gridDim.x) {
results[i] = INFTY<T>();
for(int j=0; j<datasize; ++j) {
if(j != i) {
temp = dist<T,T,Dim,metric>(&data[i*Dim], &data[j*Dim]);
if(temp < results[i]) {
results[i] = temp;
res_idx[i] = j;
}
}
}
}
}
#define TEST_BLOCKS 512
#define TEST_THREADS 32
template<typename R>
void GPU_top1_nopq(std::shared_ptr<VectorSet>& real_vecset, DistCalcMethod distMethod, int* res_idx, float* res_dist) {
R* d_data;
CUDA_CHECK(cudaMalloc(&d_data, real_vecset->Count()*real_vecset->Dimension()*sizeof(R)));
CUDA_CHECK(cudaMemcpy(d_data, reinterpret_cast<R*>(real_vecset->GetVector(0)), real_vecset->Count()*real_vecset->Dimension()*sizeof(R), cudaMemcpyHostToDevice));
// float* nearest = new float[real_vecset->Count()];
float* d_res_dist;
CUDA_CHECK(cudaMalloc(&d_res_dist, real_vecset->Count()*sizeof(float)));
int* d_res_idx;
CUDA_CHECK(cudaMalloc(&d_res_idx, real_vecset->Count()*sizeof(int)));
// Run kernel that performs
// Create options for different dims and metrics
if(real_vecset->Dimension() == 256) {
top1_nopq_kernel<R,256,(int)DistMetric::L2><<<TEST_BLOCKS,TEST_THREADS>>>(d_data, d_res_idx, d_res_dist, real_vecset->Count());
}
else {
printf("Add support for testing with %d dimensions\n", real_vecset->Dimension());
}
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemcpy(res_idx, d_res_idx, real_vecset->Count()*sizeof(int), cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(res_dist, d_res_dist, real_vecset->Count()*sizeof(float), cudaMemcpyDeviceToHost));
// Optional check results for debugging
}
void GPU_nopq_alltype(std::shared_ptr<VectorSet>& real_vecset, DistCalcMethod distMethod, int* res_idx, float* res_dist) {
GPU_top1_nopq<float>(real_vecset, distMethod, res_idx, res_dist);
}
template<typename T, int Dim>
__global__ void top1_pq_kernel(uint8_t* data, int* res_idx, float* res_dist, int datasize, GPU_Quantizer* quantizer) {
if(threadIdx.x==0 && blockIdx.x==0) {
printf("Quantizer numSub:%d, KsPerSub:%ld, BlockSize:%ld, dimPerSub:%ld\n", quantizer->m_NumSubvectors, quantizer->m_KsPerSubvector, quantizer->m_BlockSize, quantizer->m_DimPerSubvector);
printf("dataSize:%d, data: %u, %u\n", datasize, data[0] & 255, data[1] & 255);
}
float temp;
for(int i=blockIdx.x*blockDim.x + threadIdx.x; i<datasize; i+=blockDim.x*gridDim.x) {
res_dist[i] = INFTY<T>();
for(int j=0; j<datasize; ++j) {
if(i != j) {
temp = quantizer->dist(&data[i*Dim], &data[j*Dim]);
if(temp < res_dist[i]) {
res_dist[i] = temp;
res_idx[i] = j;
if(i==0) printf("j:%d, dist:%f\n", j, temp);
}
}
}
}
}
template<typename R>
void GPU_top1_pq(std::shared_ptr<VectorSet>& real_vecset, std::shared_ptr<VectorSet>& quan_vecset,DistCalcMethod distMethod, std::shared_ptr<COMMON::IQuantizer>& quantizer, int* res_idx, float* res_dist) {
printf("Running GPU PQ - PQ dims:%d\n", quan_vecset->Dimension());
std::shared_ptr<VectorIndex> vecIndex = SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType<uint8_t>());
vecIndex->SetQuantizer(quantizer);
vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod));
GPU_Quantizer* d_quantizer = NULL;
GPU_Quantizer* h_quantizer = NULL;
printf("QTYPE:%d\n", (int)(quantizer->GetQuantizerType()));
printf("Creating GPU_Quantizer\n");
h_quantizer = new GPU_Quantizer(quantizer, DistMetric::L2); // TODO - add other metric option
CUDA_CHECK(cudaMalloc(&d_quantizer, sizeof(GPU_Quantizer)));
CUDA_CHECK(cudaMemcpy(d_quantizer, h_quantizer, sizeof(GPU_Quantizer), cudaMemcpyHostToDevice));
uint8_t* d_data;
CUDA_CHECK(cudaMalloc(&d_data, quan_vecset->Count()*quan_vecset->Dimension()*sizeof(uint8_t)));
CUDA_CHECK(cudaMemcpy(d_data, reinterpret_cast<uint8_t*>(quan_vecset->GetVector(0)), quan_vecset->Count()*quan_vecset->Dimension()*sizeof(uint8_t), cudaMemcpyHostToDevice));
// float* nearest = new float[quan_vecset->Count()];
// float* d_nearest;
// CUDA_CHECK(cudaMalloc(&d_nearest, quan_vecset->Count()*sizeof(float)));
float* d_res_dist;
CUDA_CHECK(cudaMalloc(&d_res_dist, quan_vecset->Count()*sizeof(float)));
int* d_res_idx;
CUDA_CHECK(cudaMalloc(&d_res_idx, quan_vecset->Count()*sizeof(int)));
// Run kernel that performs
// Create options for different dims and metrics
if(quan_vecset->Dimension() == 128) {
top1_pq_kernel<R,128><<<TEST_BLOCKS,TEST_THREADS>>>(d_data, d_res_idx, d_res_dist, quan_vecset->Count(), d_quantizer);
}
else {
printf("Add support for testing with %d PQ dimensions\n", quan_vecset->Dimension());
}
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemcpy(res_idx, d_res_idx, quan_vecset->Count()*sizeof(int), cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaMemcpy(res_dist, d_res_dist, quan_vecset->Count()*sizeof(float), cudaMemcpyDeviceToHost));
// Optional check results for debugging
}
void GPU_pq_alltype(std::shared_ptr<VectorSet>& real_vecset, std::shared_ptr<VectorSet>& quan_vecset, DistCalcMethod distMethod, std::shared_ptr<COMMON::IQuantizer>& quantizer, int* res_idx, float* res_dist) {
GPU_top1_pq<float>(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist);
}

255
Test/cuda/knn_tests.cu Normal file
Просмотреть файл

@ -0,0 +1,255 @@
#include "common.hxx"
#include "inc/Core/Common/cuda/KNN.hxx"
template<typename T, typename SUMTYPE, int Dim, int metric>
__global__ void test_KNN(PointSet<T>* ps, int* results, int rows, int K) {
extern __shared__ char sharememory[];
DistPair<SUMTYPE>* threadList = (&((DistPair<SUMTYPE>*)sharememory)[K*threadIdx.x]);
T query[Dim];
T candidate_vec[Dim];
DistPair<SUMTYPE> target;
DistPair<SUMTYPE> candidate;
DistPair<SUMTYPE> temp;
bool good;
SUMTYPE max_dist = INFTY<SUMTYPE>();
int read_id, write_id;
SUMTYPE (*dist_comp)(T*,T*) = &dist<T,SUMTYPE,Dim,metric>;
for(size_t i=blockIdx.x*blockDim.x + threadIdx.x; i<rows; i+=blockDim.x*gridDim.x) {
for(int j=0; j<Dim; ++j) {
query[j] = ps->getVec(i)[j];
}
for(int k=0; k<K; k++) {
threadList[k].dist=INFTY<SUMTYPE>();
}
for(size_t j=0; j<rows; ++j) {
good = true;
candidate.idx = j;
candidate.dist = dist_comp(query, ps->getVec(j));
if(max_dist > candidate.dist) {
for(read_id=0; candidate.dist > threadList[read_id].dist && good; read_id++) {
if(violatesRNG<T,SUMTYPE>(candidate_vec, ps->getVec(threadList[read_id].idx), candidate.dist, dist_comp)) {
good = false;
}
}
if(good) {
target = threadList[read_id];
threadList[read_id] = candidate;
read_id++;
for(write_id = read_id; read_id < K && threadList[read_id].idx != -1; read_id++) {
if(!violatesRNG<T, SUMTYPE>(ps->getVec(threadList[read_id].idx), candidate_vec, threadList[read_id].dist, dist_comp)) {
if(read_id == write_id) {
temp = threadList[read_id];
threadList[write_id] = target;
target = temp;
}
else {
threadList[write_id] = target;
target = threadList[read_id];
}
write_id++;
}
}
if(write_id < K) {
threadList[write_id] = target;
write_id++;
}
for(int k=write_id; k<K && threadList[k].idx != -1; k++) {
threadList[k].dist = INFTY<SUMTYPE>();
threadList[k].idx = -1;
}
max_dist = threadList[K-1].dist;
}
}
}
for(size_t j=0; j<K; j++) {
results[(size_t)(i)*K+j] = threadList[j].idx;
}
}
}
template<typename T, typename SUMTYPE, int dim, int K>
int GPUBuildKNNCosineTest(int rows) {
T* data = create_dataset<T>(rows, dim);
T* d_data;
CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice));
int* d_results;
CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(int)));
PointSet<T> h_ps;
h_ps.dim = dim;
h_ps.data = d_data;
PointSet<T>* d_ps;
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
test_KNN<T,SUMTYPE,dim,(int)DistMetric::Cosine><<<1024, 64, K*64*sizeof(DistPair<SUMTYPE>)>>>(d_ps, d_results, rows, K);
CUDA_CHECK(cudaDeviceSynchronize());
int* h_results = new int[K*rows];
CUDA_CHECK(cudaMemcpy(h_results, d_results, rows*K*sizeof(int), cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaDeviceSynchronize());
// Verify that the neighbor list of each vector is ordered correctly
for(int i=0; i<rows; ++i) {
for(int j=0; j<K-1; ++j) {
int neighborId = h_results[i*K+j];
int nextNeighborId = h_results[i*K+j+1];
if(neighborId != -1 && nextNeighborId != -1) {
SUMTYPE neighborDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeCosineDistance<T>(&data[i*dim], &data[neighborId*dim], dim));
SUMTYPE nextDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeCosineDistance<T>(&data[i*dim], &data[nextNeighborId*dim], dim));
if(neighborDist > nextDist) {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Neighbor list not in ascending distance order. i:%d, neighbor:%d (dist:%f), next:%d (dist:%f)\n", i, neighborId, neighborDist, nextNeighborId, nextDist);
return 1;
}
}
}
}
CUDA_CHECK(cudaFree(d_data));
CUDA_CHECK(cudaFree(d_results));
CUDA_CHECK(cudaFree(d_ps));
return 0;
}
int GPUBuildKNNCosineTest() {
int errors = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype tests...\n");
errors += GPUBuildKNNCosineTest<float, float, 10, 10>(1000);
errors += GPUBuildKNNCosineTest<float, float, 100, 10>(1000);
errors += GPUBuildKNNCosineTest<float, float, 200, 10>(1000);
errors += GPUBuildKNNCosineTest<float, float, 384, 10>(1000);
errors += GPUBuildKNNCosineTest<float, float, 1024, 10>(1000);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int32 datatype tests...\n");
errors += GPUBuildKNNCosineTest<int, int, 10, 10>(1000);
errors += GPUBuildKNNCosineTest<int, int, 100, 10>(1000);
errors += GPUBuildKNNCosineTest<int, int, 200, 10>(1000);
errors += GPUBuildKNNCosineTest<int, int, 384, 10>(1000);
errors += GPUBuildKNNCosineTest<int, int, 1024, 10>(1000);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype tests...\n");
errors += GPUBuildKNNCosineTest<int8_t, int32_t, 100, 10>(1000);
errors += GPUBuildKNNCosineTest<int8_t, int32_t, 200, 10>(1000);
errors += GPUBuildKNNCosineTest<int8_t, int32_t, 384, 10>(1000);
errors += GPUBuildKNNCosineTest<int8_t, int32_t, 1024, 10>(1000);
return errors;
}
template<typename T, typename SUMTYPE, int dim, int K>
int GPUBuildKNNL2Test(int rows) {
T* data = create_dataset<T>(rows, dim);
T* d_data;
CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice));
int* d_results;
CUDA_CHECK(cudaMalloc(&d_results, rows*K*sizeof(int)));
PointSet<T> h_ps;
h_ps.dim = dim;
h_ps.data = d_data;
PointSet<T>* d_ps;
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
test_KNN<T,SUMTYPE,dim,(int)DistMetric::L2><<<1024, 64, K*64*sizeof(DistPair<SUMTYPE>)>>>(d_ps, d_results, rows, K);
CUDA_CHECK(cudaDeviceSynchronize());
int* h_results = new int[K*rows];
CUDA_CHECK(cudaMemcpy(h_results, d_results, rows*K*sizeof(int), cudaMemcpyDeviceToHost));
CUDA_CHECK(cudaDeviceSynchronize());
// Verify that the neighbor list of each vector is ordered correctly
for(int i=0; i<rows; ++i) {
for(int j=0; j<K-1; ++j) {
int neighborId = h_results[i*K+j];
int nextNeighborId = h_results[i*K+j+1];
if(neighborId != -1 && nextNeighborId != -1) {
SUMTYPE neighborDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeL2Distance<T>(&data[i*dim], &data[neighborId*dim], dim));
SUMTYPE nextDist = (SUMTYPE)(SPTAG::COMMON::DistanceUtils::ComputeL2Distance<T>(&data[i*dim], &data[nextNeighborId*dim], dim));
if(neighborDist > nextDist) {
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Error, "Neighbor list not in ascending distance order. i:%d, neighbor:%d (dist:%f), next:%d (dist:%f)\n", i, neighborId, neighborDist, nextNeighborId, nextDist);
return 1;
}
}
}
}
CUDA_CHECK(cudaFree(d_data));
CUDA_CHECK(cudaFree(d_results));
CUDA_CHECK(cudaFree(d_ps));
return 0;
}
int GPUBuildKNNL2Test() {
int errors = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype tests...\n");
errors += GPUBuildKNNL2Test<float, float, 10, 10>(1000);
errors += GPUBuildKNNL2Test<float, float, 100, 10>(1000);
errors += GPUBuildKNNL2Test<float, float, 200, 10>(1000);
errors += GPUBuildKNNL2Test<float, float, 384, 10>(1000);
errors += GPUBuildKNNL2Test<float, float, 1024, 10>(1000);
CHECK_ERRS(errors)
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int32 datatype tests...\n");
errors += GPUBuildKNNL2Test<int, int, 10, 10>(1000);
errors += GPUBuildKNNL2Test<int, int, 100, 10>(1000);
errors += GPUBuildKNNL2Test<int, int, 200, 10>(1000);
errors += GPUBuildKNNL2Test<int, int, 384, 10>(1000);
errors += GPUBuildKNNL2Test<int, int, 1024, 10>(1000);
CHECK_ERRS(errors)
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Int8 datatype tests...\n");
errors += GPUBuildKNNL2Test<int8_t, int32_t, 100, 10>(1000);
errors += GPUBuildKNNL2Test<int8_t, int32_t, 200, 10>(1000);
errors += GPUBuildKNNL2Test<int8_t, int32_t, 384, 10>(1000);
errors += GPUBuildKNNL2Test<int8_t, int32_t, 1024, 10>(1000);
CHECK_ERRS(errors)
}
int GPUBuildKNNTest() {
int errors = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting KNN Cosine metric tests\n");
errors += GPUBuildKNNCosineTest();
CHECK_ERRS(errors)
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting KNN L2 metric tests\n");
errors += GPUBuildKNNL2Test();
CHECK_ERRS(errors)
return errors;
}
int GPUBuildTPTreeTest();

593
Test/cuda/pq_perf.cpp Normal file
Просмотреть файл

@ -0,0 +1,593 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
//#include "../inc/Test.h"
#define BOOST_TEST_MODULE GPU
#include <iostream>
#include <boost/test/included/unit_test.hpp>
#include <boost/filesystem.hpp>
#include <limits>
#include <chrono>
#include "inc/Core/Common/cuda/params.h"
#include <random>
#include "inc/Helper/VectorSetReader.h"
#include "inc/Core/Common/PQQuantizer.h"
#include "inc/Core/VectorIndex.h"
#include "inc/Core/Common/CommonUtils.h"
#include "inc/Core/Common/QueryResultSet.h"
#include "inc/Core/Common/DistanceUtils.h"
#include <thread>
#include <iostream>
#include <unordered_set>
#include <ctime>
using namespace SPTAG;
/*
template <typename T>
void Search(std::shared_ptr<VectorIndex>& vecIndex, std::shared_ptr<VectorSet>& queryset, int k, std::shared_ptr<VectorSet>& truth)
{
std::vector<SPTAG::COMMON::QueryResultSet<T>> res(queryset->Count(), SPTAG::COMMON::QueryResultSet<T>(nullptr, k * 2));
auto t1 = std::chrono::high_resolution_clock::now();
for (SizeType i = 0; i < queryset->Count(); i++)
{
res[i].Reset();
res[i].SetTarget((const T*)queryset->GetVector(i), vecIndex->m_pQuantizer);
vecIndex->SearchIndex(res[i]);
}
auto t2 = std::chrono::high_resolution_clock::now();
std::cout << "Search time: " << (std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() / (float)(queryset->Count())) << "us" << std::endl;
float eps = 1e-6f, recall = 0;
int truthDimension = min(k, truth->Dimension());
for (SizeType i = 0; i < queryset->Count(); i++) {
SizeType* nn = (SizeType*)(truth->GetVector(i));
std::vector<bool> visited(2 * k, false);
for (int j = 0; j < truthDimension; j++) {
float truthdist = vecIndex->ComputeDistance(res[i].GetQuantizedTarget(), vecIndex->GetSample(nn[j]));
for (int l = 0; l < k*2; l++) {
if (visited[l]) continue;
//std::cout << res[i].GetResult(l)->Dist << " " << truthdist << std::endl;
if (res[i].GetResult(l)->VID == nn[j]) {
recall += 1.0;
visited[l] = true;
break;
}
else if (fabs(res[i].GetResult(l)->Dist - truthdist) <= eps) {
recall += 1.0;
visited[l] = true;
break;
}
}
}
}
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Recall %d@%d: %f\n", truthDimension, k*2, recall / queryset->Count() / truthDimension);
}
*/
/*
template<typename T>
std::shared_ptr<VectorIndex> PerfBuild(IndexAlgoType algo, std::string distCalcMethod, std::shared_ptr<VectorSet>& vec, std::shared_ptr<MetadataSet>& meta, std::shared_ptr<VectorSet>& queryset, int k, std::shared_ptr<VectorSet>& truth, std::string out, std::shared_ptr<COMMON::IQuantizer> quantizer)
{
std::shared_ptr<VectorIndex> vecIndex = SPTAG::VectorIndex::CreateInstance(algo, SPTAG::GetEnumValueType<T>());
vecIndex->SetQuantizer(quantizer);
BOOST_CHECK(nullptr != vecIndex);
if (algo == IndexAlgoType::KDT) vecIndex->SetParameter("KDTNumber", "2");
vecIndex->SetParameter("DistCalcMethod", distCalcMethod);
vecIndex->SetParameter("NumberOfThreads", "12");
vecIndex->SetParameter("RefineIterations", "3");
vecIndex->SetParameter("MaxCheck", "4096");
vecIndex->SetParameter("MaxCheckForRefineGraph", "8192");
BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->BuildIndex(vec, meta, true));
BOOST_CHECK(SPTAG::ErrorCode::Success == vecIndex->SaveIndex(out));
//Search<T>(vecIndex, queryset, k, truth);
return vecIndex;
}
*/
template <typename R>
void GenerateReconstructData(std::shared_ptr<VectorSet>& real_vecset, std::shared_ptr<VectorSet>& rec_vecset, std::shared_ptr<VectorSet>& quan_vecset, std::shared_ptr<MetadataSet>& metaset, std::shared_ptr<VectorSet>& queryset, std::shared_ptr<VectorSet>& truth, DistCalcMethod distCalcMethod, int k, std::shared_ptr<COMMON::IQuantizer>& quantizer)
{
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<R> dist(-1000, 1000);
int n = 20000, q = 200;
int m = 256;
int M = 128;
int Ks = 256;
int QuanDim = m / M;
std::string CODEBOOK_FILE = "quantest_quantizer.bin";
if (fileexists("quantest_vector.bin") && fileexists("quantest_query.bin")) {
std::shared_ptr<Helper::ReaderOptions> options(new Helper::ReaderOptions(GetEnumValueType<R>(), m, VectorFileType::DEFAULT));
auto vectorReader = Helper::VectorSetReader::CreateInstance(options);
if (ErrorCode::Success != vectorReader->LoadFile("quantest_vector.bin"))
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n");
exit(1);
}
real_vecset = vectorReader->GetVectorSet();
if (ErrorCode::Success != vectorReader->LoadFile("quantest_query.bin"))
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read query file.\n");
exit(1);
}
queryset = vectorReader->GetVectorSet();
}
else {
ByteArray real_vec = ByteArray::Alloc(sizeof(R) * n * m);
for (int i = 0; i < n * m; i++) {
((R*)real_vec.Data())[i] = (R)(dist(gen));
}
real_vecset.reset(new BasicVectorSet(real_vec, GetEnumValueType<R>(), m, n));
// real_vecset->Save("quantest_vector.bin");
ByteArray real_query = ByteArray::Alloc(sizeof(R) * q * m);
for (int i = 0; i < q * m; i++) {
((R*)real_query.Data())[i] = (R)(dist(gen));
}
queryset.reset(new BasicVectorSet(real_query, GetEnumValueType<R>(), m, q));
// queryset->Save("quantest_query.bin");
}
if (fileexists(("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)).c_str())) {
std::shared_ptr<Helper::ReaderOptions> options(new Helper::ReaderOptions(GetEnumValueType<float>(), k, VectorFileType::DEFAULT));
auto vectorReader = Helper::VectorSetReader::CreateInstance(options);
if (ErrorCode::Success != vectorReader->LoadFile("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod)))
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read truth file.\n");
exit(1);
}
truth = vectorReader->GetVectorSet();
}
else {
omp_set_num_threads(5);
ByteArray tru = ByteArray::Alloc(sizeof(SizeType) * queryset->Count() * k);
#pragma omp parallel for
for (SizeType i = 0; i < queryset->Count(); ++i)
{
SizeType* neighbors = ((SizeType*)tru.Data()) + i * k;
COMMON::QueryResultSet<R> res((const R*)queryset->GetVector(i), k);
for (SizeType j = 0; j < real_vecset->Count(); j++)
{
float dist = COMMON::DistanceUtils::ComputeDistance(res.GetTarget(), reinterpret_cast<R*>(real_vecset->GetVector(j)), queryset->Dimension(), distCalcMethod);
res.AddPoint(j, dist);
}
res.SortResult();
for (int j = 0; j < k; j++) neighbors[j] = res.GetResult(j)->VID;
}
truth.reset(new BasicVectorSet(tru, GetEnumValueType<float>(), k, queryset->Count()));
// truth->Save("quantest_truth." + SPTAG::Helper::Convert::ConvertToString(distCalcMethod));
}
if (fileexists(CODEBOOK_FILE.c_str()) && fileexists("quantest_quan_vector.bin") && fileexists("quantest_rec_vector.bin")) {
auto ptr = SPTAG::f_createIO();
if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) {
BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error");
}
quantizer->LoadIQuantizer(ptr);
BOOST_ASSERT(quantizer);
std::shared_ptr<Helper::ReaderOptions> options(new Helper::ReaderOptions(GetEnumValueType<R>(), m, VectorFileType::DEFAULT));
auto vectorReader = Helper::VectorSetReader::CreateInstance(options);
if (ErrorCode::Success != vectorReader->LoadFile("quantest_rec_vector.bin"))
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n");
exit(1);
}
rec_vecset = vectorReader->GetVectorSet();
std::shared_ptr<Helper::ReaderOptions> quanOptions(new Helper::ReaderOptions(GetEnumValueType<std::uint8_t>(), M, VectorFileType::DEFAULT));
vectorReader = Helper::VectorSetReader::CreateInstance(quanOptions);
if (ErrorCode::Success != vectorReader->LoadFile("quantest_quan_vector.bin"))
{
SPTAGLIB_LOG(Helper::LogLevel::LL_Error, "Failed to read vector file.\n");
exit(1);
}
quan_vecset = vectorReader->GetVectorSet();
}
else {
omp_set_num_threads(16);
std::cout << "Building codebooks!" << std::endl;
R* vecs = (R*)(real_vecset->GetData());
std::unique_ptr<R[]> codebooks = std::make_unique<R[]>(M * Ks * QuanDim);
std::unique_ptr<int[]> belong(new int[n]);
for (int i = 0; i < M; i++) {
R* kmeans = codebooks.get() + i * Ks * QuanDim;
for (int j = 0; j < Ks; j++) {
std::memcpy(kmeans + j * QuanDim, vecs + j * m + i * QuanDim, sizeof(R) * QuanDim);
}
int cnt = 100;
while (cnt--) {
//calculate cluster
#pragma omp parallel for
for (int ii = 0; ii < n; ii++) {
double min_dis = 1e9;
int min_id = 0;
for (int jj = 0; jj < Ks; jj++) {
double now_dis = COMMON::DistanceUtils::ComputeDistance(vecs + ii * m + i * QuanDim, kmeans + jj * QuanDim, QuanDim, DistCalcMethod::L2);
if (now_dis < min_dis) {
min_dis = now_dis;
min_id = jj;
}
}
belong[ii] = min_id;
}
//recalculate kmeans
std::memset(kmeans, 0, sizeof(R) * Ks * QuanDim);
#pragma omp parallel for
for (int ii = 0; ii < Ks; ii++) {
int num = 0;
for (int jj = 0; jj < n; jj++) {
if (belong[jj] == ii) {
num++;
for (int kk = 0; kk < QuanDim; kk++) {
kmeans[ii * QuanDim + kk] += vecs[jj * m + i * QuanDim + kk];
}
}
}
for (int jj = 0; jj < QuanDim; jj++) {
kmeans[ii * QuanDim + jj] /= num;
}
}
}
}
std::cout << "Building Finish!" << std::endl;
quantizer = std::make_shared<SPTAG::COMMON::PQQuantizer<R>>(M, Ks, QuanDim, false, std::move(codebooks));
printf("After built, pq type:%d\n", (int)quantizer->GetQuantizerType());
auto ptr = SPTAG::f_createIO();
if (ptr == nullptr || !ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::out)) {
BOOST_ASSERT("Canot Open CODEBOOK_FILE to write!" == "Error");
}
// quantizer->SaveQuantizer(ptr);
ptr->ShutDown();
if (!ptr->Initialize(CODEBOOK_FILE.c_str(), std::ios::binary | std::ios::in)) {
BOOST_ASSERT("Canot Open CODEBOOK_FILE to read!" == "Error");
}
quantizer->LoadIQuantizer(ptr);
BOOST_ASSERT(quantizer);
rec_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(R) * n * m), GetEnumValueType<R>(), m, n));
quan_vecset.reset(new BasicVectorSet(ByteArray::Alloc(sizeof(std::uint8_t) * n * M), GetEnumValueType<std::uint8_t>(), M, n));
for (int i = 0; i < n; i++) {
auto nvec = &vecs[i * m];
quantizer->QuantizeVector(nvec, (uint8_t*)quan_vecset->GetVector(i));
quantizer->ReconstructVector((uint8_t*)quan_vecset->GetVector(i), rec_vecset->GetVector(i));
}
// quan_vecset->Save("quantest_quan_vector.bin");
// rec_vecset->Save("quantest_rec_vector.bin");
}
}
void GPU_nopq_alltype(std::shared_ptr<VectorSet>& real_vecset, DistCalcMethod distMethod, int* res_idx, float* res_dist);
void GPU_pq_alltype(std::shared_ptr<VectorSet>& real_vecset, std::shared_ptr<VectorSet>& quan_vecset, DistCalcMethod distMethod, std::shared_ptr<COMMON::IQuantizer>& quantizer, int* res_idx, float* res_dist);
bool DEBUG_REPORT = false;
template<typename R>
void CPU_top1_nopq(std::shared_ptr<VectorSet>& real_vecset, DistCalcMethod distMethod, int numThreads, int* result, float* res_dist, bool randomized) {
float testDist;
// float* nearest = new float[real_vecset->Count()];
printf("CPU top1 nopq - real dim:%d\n", real_vecset->Dimension());
if(randomized) {
srand(time(NULL));
#pragma omp parallel for num_threads(numThreads)
for(int i=0; i<real_vecset->Count(); ++i) {
int idx;
res_dist[i] = std::numeric_limits<float>::max();
for(int j=0; j<real_vecset->Count(); ++j) {
idx = rand() % real_vecset->Count();
if(i==idx) continue;
testDist = COMMON::DistanceUtils::ComputeDistance(reinterpret_cast<R*>(real_vecset->GetVector(i)), reinterpret_cast<R*>(real_vecset->GetVector(idx)), real_vecset->Dimension(), distMethod);
if(testDist < res_dist[i]) {
res_dist[i] = testDist;
result[i] = idx;
}
}
}
}
else {
#pragma omp parallel for num_threads(numThreads)
for(int i=0; i<real_vecset->Count(); ++i) {
res_dist[i] = std::numeric_limits<float>::max();
for(int j=0; j<real_vecset->Count(); ++j) {
if(i==j) continue;
testDist = COMMON::DistanceUtils::ComputeDistance(reinterpret_cast<R*>(real_vecset->GetVector(i)), reinterpret_cast<R*>(real_vecset->GetVector(j)), real_vecset->Dimension(), distMethod);
if(testDist < res_dist[i]) {
res_dist[i] = testDist;
result[i] = j;
}
}
}
}
/*
if(DEBUG_REPORT) {
for(int i=0; i<real_vecset->Count(); ++i) {
printf("%f\n", nearest[i]);
}
}
*/
}
template<typename R>
void CPU_top1_pq(std::shared_ptr<VectorSet>& real_vecset, std::shared_ptr<VectorSet>& quan_vecset, DistCalcMethod distMethod, std::shared_ptr<COMMON::IQuantizer>& quantizer, int numThreads, int* res_idx, float* res_dist, bool randomized) {
float testDist;
// float* nearest = new float[real_vecset->Count()];
printf("CPU top1 PQ - PQ dim:%d, count:%d\n", quan_vecset->Dimension(), quan_vecset->Count());
std::shared_ptr<VectorIndex> vecIndex = SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType<uint8_t>());
vecIndex->SetQuantizer(quantizer);
vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod));
if(randomized) {
int idx;
#pragma omp parallel for num_threads(numThreads)
for(int i=0; i<quan_vecset->Count(); ++i) {
res_dist[i] = std::numeric_limits<float>::max();
for(int j=0; j<quan_vecset->Count(); ++j) {
idx = rand() % (quan_vecset->Count());
if(i==idx) continue;
testDist = vecIndex->ComputeDistance(quan_vecset->GetVector(i), quan_vecset->GetVector(idx));
if(testDist < res_dist[i]) {
res_dist[i] = testDist;
res_idx[i] = idx;
}
}
}
}
else {
#pragma omp parallel for num_threads(numThreads)
for(int i=0; i<quan_vecset->Count(); ++i) {
res_dist[i] = std::numeric_limits<float>::max();
for(int j=0; j<quan_vecset->Count(); ++j) {
if(i==j) continue;
testDist = vecIndex->ComputeDistance(quan_vecset->GetVector(i), quan_vecset->GetVector(j));
if(testDist < res_dist[i]) {
res_dist[i] = testDist;
res_idx[i] = j;
}
}
}
}
}
#define EPS 0.01
void verify_results(int* gt_idx, float* gt_dist, int* res_idx, float* res_dist, int N) {
// Verify both CPU baselines have same result
for(int i=0; i<N; ++i) {
if(gt_idx[i] != res_idx[i] || gt_dist[i] > res_dist[i]*(1+EPS)) {
printf("mismatch! gt:%d (%f), result:%d (%f)\n", gt_idx[i], gt_dist[i], res_idx[i], res_dist[i]);
}
}
}
void compute_accuracy(int* gt_idx, float* gt_dist, int* res_idx, float* res_dist, int N) {
float matches=0.0;
float dist_sum=0.0;
float total_dists=0.0;
for(int i=0; i<N; ++i) {
if(gt_idx[i] == res_idx[i]) {
matches++;
}
total_dists += gt_dist[i];
dist_sum += abs(gt_dist[i]-res_dist[i]);
}
printf("KNN accuracy:%0.3f, avg. distance error:%4.3e\n", matches/(float)N, dist_sum/total_dists);
}
template <typename R>
void DistancePerfSuite(IndexAlgoType algo, DistCalcMethod distMethod)
{
std::shared_ptr<VectorSet> real_vecset, rec_vecset, quan_vecset, queryset, truth;
std::shared_ptr<MetadataSet> metaset;
std::shared_ptr<COMMON::IQuantizer> quantizer;
GenerateReconstructData<R>(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, quantizer);
std::shared_ptr<VectorIndex> vecIndex = SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType<uint8_t>());
vecIndex->SetQuantizer(quantizer);
vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod));
// printf("Truth dimension:%d, count:%d\n", truth->Dimension(), truth->Count());
printf("numSubvec:%d\n", quantizer->GetNumSubvectors());
printf("quantizer type:%d\n", (int)(quantizer->GetQuantizerType()));
int* gt_idx = new int[real_vecset->Count()];
float* gt_dist = new float[real_vecset->Count()];
int* res_idx = new int[real_vecset->Count()];
float* res_dist = new float[real_vecset->Count()];
// BASELINE non-PQ perf timing:
// CPU distance comparison timing with non-quantized
auto start_t = std::chrono::high_resolution_clock::now();
CPU_top1_nopq<R>(real_vecset, distMethod, 1, gt_idx, gt_dist, false);
auto end_t = std::chrono::high_resolution_clock::now();
double CPU_baseline_t = GET_CHRONO_TIME(start_t, end_t);
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_nopq<R>(real_vecset, distMethod, 16, res_idx, res_dist, false);
end_t = std::chrono::high_resolution_clock::now();
double CPU_parallel_t = GET_CHRONO_TIME(start_t, end_t);
printf("Verifying correctness of CPU non-pq baseline results...\n");
verify_results(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// GPU distance comparisons using non-quantized
start_t = std::chrono::high_resolution_clock::now();
GPU_nopq_alltype(real_vecset, distMethod, res_idx, res_dist);
end_t = std::chrono::high_resolution_clock::now();
double GPU_baseline_t = GET_CHRONO_TIME(start_t, end_t);
printf("Verifying correctness of GPU non-pq baseline results...\n");
verify_results(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// Time CPU all-to-all distance comparisons between quan_vecset
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_pq<float>(real_vecset, quan_vecset, distMethod, quantizer, 1, res_idx, res_dist, false);
end_t = std::chrono::high_resolution_clock::now();
double CPU_PQ_t = GET_CHRONO_TIME(start_t, end_t);
printf("CPU PQ single-threaded\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_pq<float>(real_vecset, quan_vecset, distMethod, quantizer, 16, res_idx, res_dist, false);
end_t = std::chrono::high_resolution_clock::now();
double CPU_PQ_parallel_t = GET_CHRONO_TIME(start_t, end_t);
printf("CPU PQ 16 threads\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// Time each GPU method of all-to-all distance between quan_vecset
start_t = std::chrono::high_resolution_clock::now();
GPU_pq_alltype(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist);
end_t = std::chrono::high_resolution_clock::now();
double GPU_PQ_t = GET_CHRONO_TIME(start_t, end_t);
printf("GPU PQ threads\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 1-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_baseline_t, CPU_PQ_t);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 16-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_parallel_t, CPU_PQ_parallel_t);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU time - baseline:%0.3lf, PQ:%0.3lf\n", GPU_baseline_t, GPU_PQ_t);
//LoadReconstructData<R>(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10);
/*
auto real_idx = PerfBuild<R>(algo, Helper::Convert::ConvertToString<DistCalcMethod>(distMethod), real_vecset, metaset, queryset, 10, truth, "real_idx", nullptr);
Search<R>(real_idx, queryset, 10, truth);
auto rec_idx = PerfBuild<R>(algo, Helper::Convert::ConvertToString<DistCalcMethod>(distMethod), rec_vecset, metaset, queryset, 10, truth, "rec_idx", nullptr);
Search<R>(rec_idx, queryset, 10, truth);
auto quan_idx = PerfBuild<std::uint8_t>(algo, Helper::Convert::ConvertToString<DistCalcMethod>(distMethod), quan_vecset, metaset, queryset, 10, truth, "quan_idx", quantizer);
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with SDC");
Search<R>(quan_idx, queryset, 10, truth);
SPTAGLIB_LOG(Helper::LogLevel::LL_Info, "Test search with ADC");
quan_idx->SetQuantizerADC(true);
Search<R>(quan_idx, queryset, 10, truth);
*/
}
template <typename R>
void DistancePerfRandomized(IndexAlgoType algo, DistCalcMethod distMethod)
{
std::shared_ptr<VectorSet> real_vecset, rec_vecset, quan_vecset, queryset, truth;
std::shared_ptr<MetadataSet> metaset;
std::shared_ptr<COMMON::IQuantizer> quantizer;
GenerateReconstructData<R>(real_vecset, rec_vecset, quan_vecset, metaset, queryset, truth, distMethod, 10, quantizer);
std::shared_ptr<VectorIndex> vecIndex = SPTAG::VectorIndex::CreateInstance(IndexAlgoType::BKT, SPTAG::GetEnumValueType<uint8_t>());
vecIndex->SetQuantizer(quantizer);
vecIndex->SetParameter("DistCalcMethod", SPTAG::Helper::Convert::ConvertToString(distMethod));
// printf("Truth dimension:%d, count:%d\n", truth->Dimension(), truth->Count());
printf("numSubvec:%d\n", quantizer->GetNumSubvectors());
printf("quantizer type:%d\n", (int)(quantizer->GetQuantizerType()));
int* gt_idx = new int[real_vecset->Count()];
float* gt_dist = new float[real_vecset->Count()];
int* res_idx = new int[real_vecset->Count()];
float* res_dist = new float[real_vecset->Count()];
// BASELINE non-PQ perf timing:
// CPU distance comparison timing with non-quantized
auto start_t = std::chrono::high_resolution_clock::now();
CPU_top1_nopq<R>(real_vecset, distMethod, 1, gt_idx, gt_dist, true);
auto end_t = std::chrono::high_resolution_clock::now();
double CPU_baseline_t = GET_CHRONO_TIME(start_t, end_t);
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_nopq<R>(real_vecset, distMethod, 16, res_idx, res_dist, true);
end_t = std::chrono::high_resolution_clock::now();
double CPU_parallel_t = GET_CHRONO_TIME(start_t, end_t);
printf("Randomized accuracy between CPU runs...\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// GPU distance comparisons using non-quantized
start_t = std::chrono::high_resolution_clock::now();
GPU_nopq_alltype(real_vecset, distMethod, res_idx, res_dist);
end_t = std::chrono::high_resolution_clock::now();
double GPU_baseline_t = GET_CHRONO_TIME(start_t, end_t);
printf("Accuracy of randomized GPU non-pq run\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// Time CPU all-to-all distance comparisons between quan_vecset
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_pq<float>(real_vecset, quan_vecset, distMethod, quantizer, 1, res_idx, res_dist, true);
end_t = std::chrono::high_resolution_clock::now();
double CPU_PQ_t = GET_CHRONO_TIME(start_t, end_t);
printf("CPU PQ single-threaded\n");
start_t = std::chrono::high_resolution_clock::now();
CPU_top1_pq<float>(real_vecset, quan_vecset, distMethod, quantizer, 16, res_idx, res_dist, true);
end_t = std::chrono::high_resolution_clock::now();
double CPU_PQ_parallel_t = GET_CHRONO_TIME(start_t, end_t);
printf("CPU PQ 16 threads\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
// Time each GPU method of all-to-all distance between quan_vecset
start_t = std::chrono::high_resolution_clock::now();
GPU_pq_alltype(real_vecset, quan_vecset, distMethod, quantizer, res_idx, res_dist);
end_t = std::chrono::high_resolution_clock::now();
double GPU_PQ_t = GET_CHRONO_TIME(start_t, end_t);
printf("GPU PQ threads\n");
compute_accuracy(gt_idx, gt_dist, res_idx, res_dist, real_vecset->Count());
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 1-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_baseline_t, CPU_PQ_t);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "CPU 16-thread time - baseline:%0.3lf, PQ:%0.3lf\n", CPU_parallel_t, CPU_PQ_parallel_t);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "GPU time - baseline:%0.3lf, PQ:%0.3lf\n", GPU_baseline_t, GPU_PQ_t);
}
BOOST_AUTO_TEST_SUITE(GPUPQPerfTest)
BOOST_AUTO_TEST_CASE(GPUPQCosineTest)
{
DistancePerfSuite<float>(IndexAlgoType::BKT, DistCalcMethod::L2);
DistancePerfRandomized<float>(IndexAlgoType::BKT, DistCalcMethod::L2);
}
BOOST_AUTO_TEST_SUITE_END()

127
Test/cuda/tptree_tests.cu Normal file
Просмотреть файл

@ -0,0 +1,127 @@
#include "common.hxx"
#include "inc/Core/Common/cuda/TPtree.hxx"
template<typename T, typename SUMTYPE, int dim>
int TPTKernelsTest(int rows) {
int errs = 0;
T* data = create_dataset<T>(rows, dim);
T* d_data;
CUDA_CHECK(cudaMalloc(&d_data, dim*rows*sizeof(T)));
CUDA_CHECK(cudaMemcpy(d_data, data, dim*rows*sizeof(T), cudaMemcpyHostToDevice));
PointSet<T> h_ps;
h_ps.dim = dim;
h_ps.data = d_data;
PointSet<T>* d_ps;
CUDA_CHECK(cudaMalloc(&d_ps, sizeof(PointSet<T>)));
CUDA_CHECK(cudaMemcpy(d_ps, &h_ps, sizeof(PointSet<T>), cudaMemcpyHostToDevice));
int levels = (int)std::log2(rows/100); // TPT levels
TPtree* tptree = new TPtree;
tptree->initialize(rows, levels, dim);
// Check that tptree structure properly initialized
CHECK_VAL(tptree->Dim,dim,errs)
CHECK_VAL(tptree->levels,levels,errs)
CHECK_VAL(tptree->N,rows,errs)
CHECK_VAL(tptree->num_leaves,pow(2,levels),errs)
// Create TPT structure and random weights
KEYTYPE* h_weights = new KEYTYPE[tptree->levels*tptree->Dim];
for(int i=0; i<tptree->levels*tptree->Dim; ++i) {
h_weights[i] = ((rand()%2)*2)-1;
}
tptree->reset();
CUDA_CHECK(cudaMemcpy(tptree->weight_list, h_weights, tptree->levels*tptree->Dim*sizeof(KEYTYPE), cudaMemcpyHostToDevice));
curandState* states;
CUDA_CHECK(cudaMalloc(&states, 1024*32*sizeof(curandState)));
initialize_rands<<<1024,32>>>(states, 0);
int nodes_on_level=1;
for(int i=0; i<tptree->levels; ++i) {
find_level_sum<T><<<1024,32>>>(d_ps, tptree->weight_list, dim, tptree->node_ids, tptree->split_keys, tptree->node_sizes, rows, nodes_on_level, i, rows);
CUDA_CHECK(cudaDeviceSynchronize());
float* split_key_sum = new float[nodes_on_level];
CUDA_CHECK(cudaMemcpy(split_key_sum, &tptree->split_keys[nodes_on_level-1], nodes_on_level*sizeof(float), cudaMemcpyDeviceToHost)); // Copy the sum to compare with mean computed later
int* node_sizes = new int[nodes_on_level];
CUDA_CHECK(cudaMemcpy(node_sizes, &tptree->node_sizes[nodes_on_level-1], nodes_on_level*sizeof(int), cudaMemcpyDeviceToHost));
compute_mean<<<1024, 32>>>(tptree->split_keys, tptree->node_sizes, tptree->num_nodes);
CUDA_CHECK(cudaDeviceSynchronize());
// Check the mean values for each node
for(int j=0; j<nodes_on_level; ++j) {
GPU_CHECK_VAL(&tptree->split_keys[nodes_on_level-1+j],split_key_sum[j]/(float)(node_sizes[j]),float,errs)
}
update_node_assignments<T><<<1024, 32>>>(d_ps, tptree->weight_list, tptree->node_ids, tptree->split_keys, tptree->node_sizes, rows, i, dim);
CUDA_CHECK(cudaDeviceSynchronize());
nodes_on_level *= 2;
}
count_leaf_sizes<<<1024, 32>>>(tptree->leafs, tptree->node_ids, rows, tptree->num_nodes - tptree->num_leaves);
CUDA_CHECK(cudaDeviceSynchronize());
// Check that total leaf node sizes equals total vectors
int total_leaf_sizes=0;
for(int j=0; j<tptree->num_leaves; ++j) {
LeafNode temp_leaf;
CUDA_CHECK(cudaMemcpy(&temp_leaf, &tptree->leafs[j], sizeof(LeafNode), cudaMemcpyDeviceToHost));
total_leaf_sizes+=temp_leaf.size;
}
CHECK_VAL_LT(total_leaf_sizes,rows,errs)
assign_leaf_points_out_batch<<<1024, 32>>>(tptree->leafs, tptree->leaf_points, tptree->node_ids, rows, tptree->num_nodes - tptree->num_leaves, 0, rows);
CUDA_CHECK(cudaDeviceSynchronize());
// Check that points were correctly assigned to leaf nodes
int* h_leaf_points = new int[rows];
CUDA_CHECK(cudaMemcpy(h_leaf_points, tptree->leaf_points, rows*sizeof(int), cudaMemcpyDeviceToHost));
for(int j=0; j<rows; ++j) {
CHECK_VAL_LT(h_leaf_points[j],tptree->num_leaves,errs)
}
CUDA_CHECK(cudaFree(d_data));
CUDA_CHECK(cudaFree(d_ps));
CUDA_CHECK(cudaFree(states));
tptree->destroy();
return errs;
}
int GPUBuildTPTTest() {
int errors = 0;
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Starting TPTree Kernel tests\n");
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "Float datatype...\n");
errors += TPTKernelsTest<float, float, 100>(1000);
errors += TPTKernelsTest<float, float, 200>(1000);
errors += TPTKernelsTest<float, float, 384>(1000);
errors += TPTKernelsTest<float, float, 1024>(1000);
// SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "int32 datatype...\n");
// errors += TPTKernelsTest<int, int, 100>(1000);
// errors += TPTKernelsTest<int, int, 200>(1000);
// errors += TPTKernelsTest<int, int, 384>(1000);
// errors += TPTKernelsTest<int, int, 1024>(1000);
SPTAGLIB_LOG(SPTAG::Helper::LogLevel::LL_Info, "int8 datatype...\n");
errors += TPTKernelsTest<int8_t, int32_t, 100>(1000);
errors += TPTKernelsTest<int8_t, int32_t, 200>(1000);
errors += TPTKernelsTest<int8_t, int32_t, 384>(1000);
errors += TPTKernelsTest<int8_t, int32_t, 1024>(1000);
return errors;
}